diff --git a/scripts/train_eagle3.py b/scripts/train_eagle3.py index 0d87c33a..cd9f41d2 100644 --- a/scripts/train_eagle3.py +++ b/scripts/train_eagle3.py @@ -485,21 +485,31 @@ def build_dataloaders( ) if args.eval_data_path is not None or args.eval_hidden_states_path is not None: if args.eval_data_path is not None: - eval_dataset = Dataset.from_generator( - generator=safe_conversations_generator, - gen_kwargs={"file_path": args.eval_data_path}, - ) - eval_eagle3_dataset = build_eagle3_dataset( - eval_dataset, - tokenizer, - args.chat_template, - args.max_length, - is_vlm=args.is_vlm, - processor=processor, - num_proc=args.build_dataset_num_proc, - is_preformatted=args.is_preformatted, - train_only_last_turn=args.train_only_last_turn, + eval_cache_params_string = ( + f"{args.eval_data_path}-" + f"{args.max_length}-" + f"{args.chat_template}-" + f"{args.target_model_path}-eval" ) + eval_cache_key = hashlib.md5(eval_cache_params_string.encode()).hexdigest() + with rank_0_priority(): + eval_dataset = Dataset.from_generator( + generator=safe_conversations_generator, + gen_kwargs={"file_path": args.eval_data_path}, + ) + eval_eagle3_dataset = build_eagle3_dataset( + eval_dataset, + tokenizer, + args.chat_template, + args.max_length, + cache_dir=os.path.join(args.cache_dir, "processed_dataset"), + cache_key=eval_cache_key, + is_vlm=args.is_vlm, + processor=processor, + num_proc=args.build_dataset_num_proc, + is_preformatted=args.is_preformatted, + train_only_last_turn=args.train_only_last_turn, + ) elif args.eval_hidden_states_path is not None: eval_eagle3_dataset = build_offline_eagle3_dataset( args.eval_hidden_states_path,