From 7ef40120a0041fed2c43901794cd85b2246455d2 Mon Sep 17 00:00:00 2001 From: Bhadresh Savani Date: Tue, 23 Mar 2021 23:07:59 +0530 Subject: [PATCH] [Examples] Added predict stage and Updated Example Template (#10868) * added predict stage * added test keyword in exception message * removed example specific saving predictions * fixed f-string error * removed extra line Co-authored-by: Stas Bekman Co-authored-by: Stas Bekman --- examples/text-classification/run_xnli.py | 56 ++++++++++++++--- .../run_{{cookiecutter.example_shortcut}}.py | 61 +++++++++++++++++-- 2 files changed, 103 insertions(+), 14 deletions(-) diff --git a/examples/text-classification/run_xnli.py b/examples/text-classification/run_xnli.py index 21870879c1..2b95e0ca95 100755 --- a/examples/text-classification/run_xnli.py +++ b/examples/text-classification/run_xnli.py @@ -207,14 +207,22 @@ def main(): # In distributed training, the load_dataset function guarantees that only one local process can concurrently # download the dataset. # Downloading and loading xnli dataset from the hub. - if model_args.train_language is None: - train_dataset = load_dataset("xnli", model_args.language, split="train") - else: - train_dataset = load_dataset("xnli", model_args.train_language, split="train") + if training_args.do_train: + if model_args.train_language is None: + train_dataset = load_dataset("xnli", model_args.language, split="train") + else: + train_dataset = load_dataset("xnli", model_args.train_language, split="train") + label_list = train_dataset.features["label"].names + + if training_args.do_eval: + eval_dataset = load_dataset("xnli", model_args.language, split="validation") + label_list = eval_dataset.features["label"].names + + if training_args.do_predict: + test_dataset = load_dataset("xnli", model_args.language, split="test") + label_list = test_dataset.features["label"].names - eval_dataset = load_dataset("xnli", model_args.language, split="validation") # Labels - label_list = train_dataset.features["label"].names num_labels = len(label_list) # Load pretrained model and tokenizer @@ -271,6 +279,9 @@ def main(): batched=True, load_from_cache_file=not data_args.overwrite_cache, ) + # Log a few random samples from the training set: + for index in random.sample(range(len(train_dataset)), 3): + logger.info(f"Sample {index} of the training set: {train_dataset[index]}.") if training_args.do_eval: if data_args.max_val_samples is not None: @@ -281,9 +292,14 @@ def main(): load_from_cache_file=not data_args.overwrite_cache, ) - # Log a few random samples from the training set: - for index in random.sample(range(len(train_dataset)), 3): - logger.info(f"Sample {index} of the training set: {train_dataset[index]}.") + if training_args.do_predict: + if data_args.max_test_samples is not None: + test_dataset = test_dataset.select(range(data_args.max_test_samples)) + test_dataset = test_dataset.map( + preprocess_function, + batched=True, + load_from_cache_file=not data_args.overwrite_cache, + ) # Get the metric function metric = load_metric("xnli") @@ -307,7 +323,7 @@ def main(): trainer = Trainer( model=model, args=training_args, - train_dataset=train_dataset, + train_dataset=train_dataset if training_args.do_train else None, eval_dataset=eval_dataset if training_args.do_eval else None, compute_metrics=compute_metrics, tokenizer=tokenizer, @@ -346,6 +362,26 @@ def main(): trainer.log_metrics("eval", metrics) trainer.save_metrics("eval", metrics) + # Prediction + if training_args.do_predict: + logger.info("*** Predict ***") + predictions, labels, metrics = trainer.predict(test_dataset) + + max_test_samples = data_args.max_test_samples if data_args.max_test_samples is not None else len(test_dataset) + metrics["test_samples"] = min(max_test_samples, len(test_dataset)) + + trainer.log_metrics("test", metrics) + trainer.save_metrics("test", metrics) + + predictions = np.argmax(predictions, axis=1) + output_test_file = os.path.join(training_args.output_dir, "test_predictions.txt") + if trainer.is_world_process_zero(): + with open(output_test_file, "w") as writer: + writer.write("index\tprediction\n") + for index, item in enumerate(predictions): + item = label_list[item] + writer.write(f"{index}\t{item}\n") + if __name__ == "__main__": main() diff --git a/templates/adding_a_new_example_script/{{cookiecutter.directory_name}}/run_{{cookiecutter.example_shortcut}}.py b/templates/adding_a_new_example_script/{{cookiecutter.directory_name}}/run_{{cookiecutter.example_shortcut}}.py index 4614d3a1fb..33d87345b1 100755 --- a/templates/adding_a_new_example_script/{{cookiecutter.directory_name}}/run_{{cookiecutter.example_shortcut}}.py +++ b/templates/adding_a_new_example_script/{{cookiecutter.directory_name}}/run_{{cookiecutter.example_shortcut}}.py @@ -139,6 +139,10 @@ class DataTrainingArguments: default=None, metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."}, ) + test_file: Optional[str] = field( + default=None, + metadata={"help": "An optional input test data file to predict the label on (a text file)."}, + ) overwrite_cache: bool = field( default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} ) @@ -160,10 +164,22 @@ class DataTrainingArguments: "value if set." }, ) + max_test_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of test examples to this " + "value if set." + }, + ) def __post_init__(self): - if self.dataset_name is None and self.train_file is None and self.validation_file is None: - raise ValueError("Need either a dataset name or a training/validation file.") + if ( + self.dataset_name is None + and self.train_file is None + and self.validation_file is None + and self.test_file is None + ): + raise ValueError("Need either a dataset name or a training/validation/test file.") else: if self.train_file is not None: extension = self.train_file.split(".")[-1] @@ -171,6 +187,9 @@ class DataTrainingArguments: if self.validation_file is not None: extension = self.validation_file.split(".")[-1] assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file." + if self.test_file is not None: + extension = self.test_file.split(".")[-1] + assert extension in ["csv", "json", "txt"], "`test_file` should be a csv, a json or a txt file." def main(): @@ -238,9 +257,13 @@ def main(): data_files = {} if data_args.train_file is not None: data_files["train"] = data_args.train_file + extension = data_args.train_file.split(".")[-1] if data_args.validation_file is not None: data_files["validation"] = data_args.validation_file - extension = data_args.train_file.split(".")[-1] + extension = data_args.validation_file.split(".")[-1] + if data_args.test_file is not None: + data_files["test"] = data_args.test_file + extension = data_args.test_file.split(".")[-1] if extension == "txt": extension = "text" datasets = load_dataset(extension, data_files=data_files) @@ -326,8 +349,10 @@ def main(): # First we tokenize all the texts. if training_args.do_train: column_names = datasets["train"].column_names - else: + elif training_args.do_eval: column_names = datasets["validation"].column_names + elif training_args.do_predict: + column_names = datasets["test"].column_names text_column_name = "text" if "text" in column_names else column_names[0] def tokenize_function(examples): @@ -365,6 +390,22 @@ def main(): load_from_cache_file=not data_args.overwrite_cache, ) + if training_args.do_predict: + if "test" not in datasets: + raise ValueError("--do_predict requires a test dataset") + test_dataset = datasets["test"] + # Selecting samples from dataset + if data_args.max_test_samples is not None: + test_dataset = test_dataset.select(range(data_args.max_test_samples)) + # tokenize test dataset + test_dataset = test_dataset.map( + tokenize_function, + batched=True, + num_proc=data_args.preprocessing_num_workers, + remove_columns=[text_column_name], + load_from_cache_file=not data_args.overwrite_cache, + ) + # Data collator data_collator=default_data_collator if not training_args.fp16 else DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8) @@ -420,6 +461,18 @@ def main(): trainer.log_metrics("eval", metrics) trainer.save_metrics("eval", metrics) + # Prediction + if training_args.do_predict: + logger.info("*** Predict ***") + predictions, labels, metrics = trainer.predict(test_dataset) + + max_test_samples = data_args.max_test_samples if data_args.max_test_samples is not None else len(test_dataset) + metrics["test_samples"] = min(max_test_samples, len(test_dataset)) + + trainer.log_metrics("test", metrics) + trainer.save_metrics("test", metrics) + + # write custom code for saving predictions according to task def _mp_fn(index): # For xla_spawn (TPUs)