From fd1d9f1ab89805fb2a8e773edbc27531b449ddea Mon Sep 17 00:00:00 2001 From: Bhadresh Savani Date: Fri, 19 Mar 2021 19:12:17 +0530 Subject: [PATCH] [Example] Updating Question Answering examples for Predict Stage (#10792) * added prediction stage and eval fix * style correction * removed extra lines --- examples/question-answering/run_qa.py | 76 +++++++++++++++--- .../question-answering/run_qa_beam_search.py | 78 ++++++++++++++++--- examples/question-answering/trainer_qa.py | 2 +- examples/question-answering/utils_qa.py | 12 +-- 4 files changed, 141 insertions(+), 27 deletions(-) diff --git a/examples/question-answering/run_qa.py b/examples/question-answering/run_qa.py index 68d7177f1d..6e4821b1ad 100755 --- a/examples/question-answering/run_qa.py +++ b/examples/question-answering/run_qa.py @@ -100,6 +100,10 @@ class DataTrainingArguments: default=None, metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."}, ) + test_file: Optional[str] = field( + default=None, + metadata={"help": "An optional input test data file to run predictions on (a csv or a json file)."}, + ) overwrite_cache: bool = field( default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} ) @@ -136,6 +140,13 @@ class DataTrainingArguments: "value if set." }, ) + max_test_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of test examples to this " + "value if set."
+ }, + ) version_2_with_negative: bool = field( default=False, metadata={"help": "If true, some of the examples do not have an answer."} ) @@ -164,8 +175,13 @@ class DataTrainingArguments: ) def __post_init__(self): - if self.dataset_name is None and self.train_file is None and self.validation_file is None: - raise ValueError("Need either a dataset name or a training/validation file.") + if ( + self.dataset_name is None + and self.train_file is None + and self.validation_file is None + and self.test_file is None + ): + raise ValueError("Need either a dataset name or a training/validation/test file.") else: if self.train_file is not None: extension = self.train_file.split(".")[-1] @@ -173,6 +189,9 @@ class DataTrainingArguments: if self.validation_file is not None: extension = self.validation_file.split(".")[-1] assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file." + if self.test_file is not None: + extension = self.test_file.split(".")[-1] + assert extension in ["csv", "json"], "`test_file` should be a csv or a json file." def main(): @@ -247,7 +266,9 @@ def main(): if data_args.validation_file is not None: data_files["validation"] = data_args.validation_file extension = data_args.validation_file.split(".")[-1] - + if data_args.test_file is not None: + data_files["test"] = data_args.test_file + extension = data_args.test_file.split(".")[-1] datasets = load_dataset(extension, data_files=data_files, field="data") # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at # https://huggingface.co/docs/datasets/loading_datasets.html. @@ -291,8 +312,10 @@ def main(): # Preprocessing is slighlty different for training and evaluation.
if training_args.do_train: column_names = datasets["train"].column_names - else: + elif training_args.do_eval: column_names = datasets["validation"].column_names + else: + column_names = datasets["test"].column_names question_column_name = "question" if "question" in column_names else column_names[0] context_column_name = "context" if "context" in column_names else column_names[1] answer_column_name = "answers" if "answers" in column_names else column_names[2] @@ -444,12 +467,12 @@ def main(): if training_args.do_eval: if "validation" not in datasets: raise ValueError("--do_eval requires a validation dataset") - eval_dataset = datasets["validation"] + eval_examples = datasets["validation"] if data_args.max_val_samples is not None: # We will select sample from whole data - eval_dataset = eval_dataset.select(range(data_args.max_val_samples)) + eval_examples = eval_examples.select(range(data_args.max_val_samples)) # Validation Feature Creation - eval_dataset = eval_dataset.map( + eval_dataset = eval_examples.map( prepare_validation_features, batched=True, num_proc=data_args.preprocessing_num_workers, @@ -460,6 +483,25 @@ def main(): # During Feature creation dataset samples might increase, we will select required samples again eval_dataset = eval_dataset.select(range(data_args.max_val_samples)) + if training_args.do_predict: + if "test" not in datasets: + raise ValueError("--do_predict requires a test dataset") + test_examples = datasets["test"] + if data_args.max_test_samples is not None: + # We will select sample from whole data + test_examples = test_examples.select(range(data_args.max_test_samples)) + # Test Feature Creation + test_dataset = test_examples.map( + prepare_validation_features, + batched=True, + num_proc=data_args.preprocessing_num_workers, + remove_columns=column_names, + load_from_cache_file=not data_args.overwrite_cache, + ) + if data_args.max_test_samples is not None: + # During Feature creation dataset samples might increase, we will select 
required samples again + test_dataset = test_dataset.select(range(data_args.max_test_samples)) + # Data collator # We have already padded to max length if the corresponding flag is True, otherwise we need to pad in the data # collator. @@ -470,7 +512,7 @@ def main(): ) # Post-processing: - def post_processing_function(examples, features, predictions): + def post_processing_function(examples, features, predictions, stage="eval"): # Post-processing: we match the start logits and end logits to answers in the original context. predictions = postprocess_qa_predictions( examples=examples, @@ -482,6 +524,7 @@ def main(): null_score_diff_threshold=data_args.null_score_diff_threshold, output_dir=training_args.output_dir, is_world_process_zero=trainer.is_world_process_zero(), + prefix=stage, ) # Format the result to the format the metric expects. if data_args.version_2_with_negative: @@ -490,7 +533,8 @@ def main(): ] else: formatted_predictions = [{"id": k, "prediction_text": v} for k, v in predictions.items()] - references = [{"id": ex["id"], "answers": ex[answer_column_name]} for ex in datasets["validation"]] + + references = [{"id": ex["id"], "answers": ex[answer_column_name]} for ex in examples] return EvalPrediction(predictions=formatted_predictions, label_ids=references) metric = load_metric("squad_v2" if data_args.version_2_with_negative else "squad") @@ -504,7 +548,7 @@ def main(): args=training_args, train_dataset=train_dataset if training_args.do_train else None, eval_dataset=eval_dataset if training_args.do_eval else None, - eval_examples=datasets["validation"] if training_args.do_eval else None, + eval_examples=eval_examples if training_args.do_eval else None, tokenizer=tokenizer, data_collator=data_collator, post_process_function=post_processing_function, @@ -543,6 +587,18 @@ def main(): trainer.log_metrics("eval", metrics) trainer.save_metrics("eval", metrics) + # Prediction + if training_args.do_predict: + logger.info("*** Predict ***") + results = 
trainer.predict(test_dataset, test_examples) + metrics = results.metrics + + max_test_samples = data_args.max_test_samples if data_args.max_test_samples is not None else len(test_dataset) + metrics["test_samples"] = min(max_test_samples, len(test_dataset)) + + trainer.log_metrics("test", metrics) + trainer.save_metrics("test", metrics) + def _mp_fn(index): # For xla_spawn (TPUs) diff --git a/examples/question-answering/run_qa_beam_search.py b/examples/question-answering/run_qa_beam_search.py index 1aebde5c81..6005a479f2 100755 --- a/examples/question-answering/run_qa_beam_search.py +++ b/examples/question-answering/run_qa_beam_search.py @@ -99,6 +99,10 @@ class DataTrainingArguments: default=None, metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."}, ) + test_file: Optional[str] = field( + default=None, + metadata={"help": "An optional input test data file to run predictions on (a csv or a json file)."}, + ) overwrite_cache: bool = field( default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} ) @@ -135,6 +139,13 @@ class DataTrainingArguments: "value if set." }, ) + max_test_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of test examples to this " + "value if set."
+ }, + ) version_2_with_negative: bool = field( default=False, metadata={"help": "If true, some of the examples do not have an answer."} ) @@ -163,8 +174,13 @@ class DataTrainingArguments: ) def __post_init__(self): - if self.dataset_name is None and self.train_file is None and self.validation_file is None: - raise ValueError("Need either a dataset name or a training/validation file.") + if ( + self.dataset_name is None + and self.train_file is None + and self.validation_file is None + and self.test_file is None + ): + raise ValueError("Need either a dataset name or a training/validation/test file.") else: if self.train_file is not None: extension = self.train_file.split(".")[-1] @@ -172,6 +188,9 @@ class DataTrainingArguments: if self.validation_file is not None: extension = self.validation_file.split(".")[-1] assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file." + if self.test_file is not None: + extension = self.test_file.split(".")[-1] + assert extension in ["csv", "json"], "`test_file` should be a csv or a json file." def main(): @@ -241,9 +260,13 @@ def main(): data_files = {} if data_args.train_file is not None: data_files["train"] = data_args.train_file + extension = data_args.train_file.split(".")[-1] if data_args.validation_file is not None: data_files["validation"] = data_args.validation_file - extension = data_args.train_file.split(".")[-1] + extension = data_args.validation_file.split(".")[-1] + if data_args.test_file is not None: + data_files["test"] = data_args.test_file + extension = data_args.test_file.split(".")[-1] datasets = load_dataset(extension, data_files=data_files, field="data") # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at # https://huggingface.co/docs/datasets/loading_datasets.html. @@ -278,8 +301,10 @@ def main(): # Preprocessing is slighlty different for training and evaluation. 
if training_args.do_train: column_names = datasets["train"].column_names - else: + elif training_args.do_eval: column_names = datasets["validation"].column_names + else: + column_names = datasets["test"].column_names question_column_name = "question" if "question" in column_names else column_names[0] context_column_name = "context" if "context" in column_names else column_names[1] answer_column_name = "answers" if "answers" in column_names else column_names[2] @@ -478,12 +503,12 @@ def main(): if training_args.do_eval: if "validation" not in datasets: raise ValueError("--do_eval requires a validation dataset") - eval_dataset = datasets["validation"] + eval_examples = datasets["validation"] if data_args.max_val_samples is not None: # Selecting Eval Samples from Dataset - eval_dataset = eval_dataset.select(range(data_args.max_val_samples)) + eval_examples = eval_examples.select(range(data_args.max_val_samples)) # Create Features from Eval Dataset - eval_dataset = eval_dataset.map( + eval_dataset = eval_examples.map( prepare_validation_features, batched=True, num_proc=data_args.preprocessing_num_workers, @@ -494,6 +519,25 @@ def main(): # Selecting Samples from Dataset again since Feature Creation might increase samples size eval_dataset = eval_dataset.select(range(data_args.max_val_samples)) + if training_args.do_predict: + if "test" not in datasets: + raise ValueError("--do_predict requires a test dataset") + test_examples = datasets["test"] + if data_args.max_test_samples is not None: + # We will select sample from whole data + test_examples = test_examples.select(range(data_args.max_test_samples)) + # Test Feature Creation + test_dataset = test_examples.map( + prepare_validation_features, + batched=True, + num_proc=data_args.preprocessing_num_workers, + remove_columns=column_names, + load_from_cache_file=not data_args.overwrite_cache, + ) + if data_args.max_test_samples is not None: + # During Feature creation dataset samples might increase, we will select 
required samples again + test_dataset = test_dataset.select(range(data_args.max_test_samples)) + # Data collator # We have already padded to max length if the corresponding flag is True, otherwise we need to pad in the data # collator. @@ -504,7 +548,7 @@ def main(): ) # Post-processing: - def post_processing_function(examples, features, predictions): + def post_processing_function(examples, features, predictions, stage="eval"): # Post-processing: we match the start logits and end logits to answers in the original context. predictions, scores_diff_json = postprocess_qa_predictions_with_beam_search( examples=examples, @@ -517,6 +561,7 @@ def main(): end_n_top=model.config.end_n_top, output_dir=training_args.output_dir, is_world_process_zero=trainer.is_world_process_zero(), + prefix=stage, ) # Format the result to the format the metric expects. if data_args.version_2_with_negative: @@ -526,7 +571,8 @@ def main(): ] else: formatted_predictions = [{"id": k, "prediction_text": v} for k, v in predictions.items()] - references = [{"id": ex["id"], "answers": ex[answer_column_name]} for ex in datasets["validation"]] + + references = [{"id": ex["id"], "answers": ex[answer_column_name]} for ex in examples] return EvalPrediction(predictions=formatted_predictions, label_ids=references) metric = load_metric("squad_v2" if data_args.version_2_with_negative else "squad") @@ -540,7 +586,7 @@ def main(): args=training_args, train_dataset=train_dataset if training_args.do_train else None, eval_dataset=eval_dataset if training_args.do_eval else None, - eval_examples=datasets["validation"] if training_args.do_eval else None, + eval_examples=eval_examples if training_args.do_eval else None, tokenizer=tokenizer, data_collator=data_collator, post_process_function=post_processing_function, @@ -580,6 +626,18 @@ def main(): trainer.log_metrics("eval", metrics) trainer.save_metrics("eval", metrics) + # Prediction + if training_args.do_predict: + logger.info("*** Predict ***") + results = 
trainer.predict(test_dataset, test_examples) + metrics = results.metrics + + max_test_samples = data_args.max_test_samples if data_args.max_test_samples is not None else len(test_dataset) + metrics["test_samples"] = min(max_test_samples, len(test_dataset)) + + trainer.log_metrics("test", metrics) + trainer.save_metrics("test", metrics) + def _mp_fn(index): # For xla_spawn (TPUs) diff --git a/examples/question-answering/trainer_qa.py b/examples/question-answering/trainer_qa.py index 04c8a976c7..db7b80c015 100644 --- a/examples/question-answering/trainer_qa.py +++ b/examples/question-answering/trainer_qa.py @@ -98,7 +98,7 @@ class QuestionAnsweringTrainer(Trainer): if isinstance(test_dataset, datasets.Dataset): test_dataset.set_format(type=test_dataset.format["type"], columns=list(test_dataset.features.keys())) - eval_preds = self.post_process_function(test_examples, test_dataset, output.predictions) + eval_preds = self.post_process_function(test_examples, test_dataset, output.predictions, "test") metrics = self.compute_metrics(eval_preds) return PredictionOutput(predictions=eval_preds.predictions, label_ids=eval_preds.label_ids, metrics=metrics) diff --git a/examples/question-answering/utils_qa.py b/examples/question-answering/utils_qa.py index aad5deccf9..9ce51e86fc 100644 --- a/examples/question-answering/utils_qa.py +++ b/examples/question-answering/utils_qa.py @@ -215,14 +215,14 @@ def postprocess_qa_predictions( assert os.path.isdir(output_dir), f"{output_dir} is not a directory." 
prediction_file = os.path.join( - output_dir, "predictions.json" if prefix is None else f"predictions_{prefix}".json + output_dir, "predictions.json" if prefix is None else f"{prefix}_predictions.json" ) nbest_file = os.path.join( - output_dir, "nbest_predictions.json" if prefix is None else f"nbest_predictions_{prefix}".json + output_dir, "nbest_predictions.json" if prefix is None else f"{prefix}_nbest_predictions.json" ) if version_2_with_negative: null_odds_file = os.path.join( - output_dir, "null_odds.json" if prefix is None else f"null_odds_{prefix}".json + output_dir, "null_odds.json" if prefix is None else f"{prefix}_null_odds.json" ) logger.info(f"Saving predictions to {prediction_file}.") @@ -403,14 +403,14 @@ def postprocess_qa_predictions_with_beam_search( assert os.path.isdir(output_dir), f"{output_dir} is not a directory." prediction_file = os.path.join( - output_dir, "predictions.json" if prefix is None else f"predictions_{prefix}".json + output_dir, "predictions.json" if prefix is None else f"{prefix}_predictions.json" ) nbest_file = os.path.join( - output_dir, "nbest_predictions.json" if prefix is None else f"nbest_predictions_{prefix}".json + output_dir, "nbest_predictions.json" if prefix is None else f"{prefix}_nbest_predictions.json" ) if version_2_with_negative: null_odds_file = os.path.join( - output_dir, "null_odds.json" if prefix is None else f"null_odds_{prefix}".json + output_dir, "null_odds.json" if prefix is None else f"{prefix}_null_odds.json" ) print(f"Saving predictions to {prediction_file}.")