From ac17f71159c671b521bfed55c8bc1b5188fea3f8 Mon Sep 17 00:00:00 2001 From: Bhadresh Savani Date: Tue, 9 Mar 2021 22:36:56 +0530 Subject: [PATCH] added max_sample args and metrics changes (#10602) --- .../run_{{cookiecutter.example_shortcut}}.py | 88 +++++++++++++------ 1 file changed, 60 insertions(+), 28 deletions(-) diff --git a/templates/adding_a_new_example_script/{{cookiecutter.directory_name}}/run_{{cookiecutter.example_shortcut}}.py b/templates/adding_a_new_example_script/{{cookiecutter.directory_name}}/run_{{cookiecutter.example_shortcut}}.py index e6dc9ecc87..e2a2991445 100755 --- a/templates/adding_a_new_example_script/{{cookiecutter.directory_name}}/run_{{cookiecutter.example_shortcut}}.py +++ b/templates/adding_a_new_example_script/{{cookiecutter.directory_name}}/run_{{cookiecutter.example_shortcut}}.py @@ -144,6 +144,20 @@ class DataTrainingArguments: default=None, metadata={"help": "The number of processes to use for the preprocessing."}, ) + max_train_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + }, + ) + max_val_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of validation examples to this " + "value if set." + }, + ) def __post_init__(self): if self.dataset_name is None and self.train_file is None and self.validation_file is None: @@ -317,13 +331,37 @@ def main(): def tokenize_function(examples): return tokenizer(examples[text_column_name], padding="max_length", truncation=True) - tokenized_datasets = datasets.map( - tokenize_function, - batched=True, - num_proc=data_args.preprocessing_num_workers, - remove_columns=[text_column_name], - load_from_cache_file=not data_args.overwrite_cache, - ) + if training_args.do_train: + if "train" not in datasets: + raise ValueError("--do_train requires a train dataset") + train_dataset = datasets["train"] + if data_args.max_train_samples is not None: + # Select Sample from Dataset + train_dataset = train_dataset.select(range(data_args.max_train_samples)) + # tokenize train dataset in batch + train_dataset = train_dataset.map( + tokenize_function, + batched=True, + num_proc=data_args.preprocessing_num_workers, + remove_columns=[text_column_name], + load_from_cache_file=not data_args.overwrite_cache, + ) + + if training_args.do_eval: + if "validation" not in datasets: + raise ValueError("--do_eval requires a validation dataset") + eval_dataset = datasets["validation"] + # Selecting samples from dataset + if data_args.max_val_samples is not None: + eval_dataset = eval_dataset.select(range(data_args.max_val_samples)) + # tokenize validation dataset + eval_dataset = eval_dataset.map( + tokenize_function, + batched=True, + num_proc=data_args.preprocessing_num_workers, + remove_columns=[text_column_name], + load_from_cache_file=not data_args.overwrite_cache, + ) # Data collator data_collator=default_data_collator if not training_args.fp16 else DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8) @@ -332,8 +370,8 @@ def main(): trainer = Trainer( model=model, args=training_args, - train_dataset=tokenized_datasets["train"] if training_args.do_train else None, - eval_dataset=tokenized_datasets["validation"] if training_args.do_eval else None, + train_dataset=train_dataset if training_args.do_train else None, + eval_dataset=eval_dataset if training_args.do_eval else None, tokenizer=tokenizer, data_collator=data_collator, ) @@ -358,33 +396,27 @@ def main(): train_result = trainer.train(resume_from_checkpoint=checkpoint) trainer.save_model() # Saves the tokenizer too for easy upload - output_train_file = os.path.join(training_args.output_dir, "train_results.txt") - if trainer.is_world_process_zero(): - with open(output_train_file, "w") as writer: - logger.info("***** Train results *****") - for key, value in sorted(train_result.metrics.items()): - logger.info(f" {key} = {value}") - writer.write(f"{key} = {value}\n") + metrics = train_result.metrics + max_train_samples = ( + data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset) + ) + metrics["train_samples"] = min(max_train_samples, len(train_dataset)) - # Need to save the state, since Trainer.save_model saves only the tokenizer with the model - trainer.state.save_to_json(os.path.join(training_args.output_dir, "trainer_state.json")) + trainer.log_metrics("train", metrics) + trainer.save_metrics("train", metrics) + trainer.save_state() # Evaluation - results = {} if training_args.do_eval: logger.info("*** Evaluate ***") - results = trainer.evaluate() + metrics = trainer.evaluate() - output_eval_file = os.path.join(training_args.output_dir, "eval_results_{{cookiecutter.example_shortcut}}.txt") - if trainer.is_world_process_zero(): - with open(output_eval_file, "w") as writer: - logger.info("***** Eval results *****") - for key, value in sorted(results.items()): - logger.info(f" {key} = {value}") - writer.write(f"{key} = {value}\n") + max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset) + metrics["eval_samples"] = min(max_val_samples, len(eval_dataset)) - return results + trainer.log_metrics("eval", metrics) + trainer.save_metrics("eval", metrics) def _mp_fn(index):