[Example] Updating Question Answering examples for Predict Stage (#10792)
* added prediction stage and eval fix * style correction * removed extra lines
This commit is contained in:
@@ -100,6 +100,10 @@ class DataTrainingArguments:
|
||||
default=None,
|
||||
metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
|
||||
)
|
||||
test_file: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "An optional input test data file to evaluate the perplexity on (a text file)."},
|
||||
)
|
||||
overwrite_cache: bool = field(
|
||||
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
|
||||
)
|
||||
@@ -136,6 +140,13 @@ class DataTrainingArguments:
|
||||
"value if set."
|
||||
},
|
||||
)
|
||||
max_test_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "For debugging purposes or quicker training, truncate the number of test examples to this "
|
||||
"value if set."
|
||||
},
|
||||
)
|
||||
version_2_with_negative: bool = field(
|
||||
default=False, metadata={"help": "If true, some of the examples do not have an answer."}
|
||||
)
|
||||
@@ -164,8 +175,13 @@ class DataTrainingArguments:
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.dataset_name is None and self.train_file is None and self.validation_file is None:
|
||||
raise ValueError("Need either a dataset name or a training/validation file.")
|
||||
if (
|
||||
self.dataset_name is None
|
||||
and self.train_file is None
|
||||
and self.validation_file is None
|
||||
and self.test_file is None
|
||||
):
|
||||
raise ValueError("Need either a dataset name or a training/validation file/test_file.")
|
||||
else:
|
||||
if self.train_file is not None:
|
||||
extension = self.train_file.split(".")[-1]
|
||||
@@ -173,6 +189,9 @@ class DataTrainingArguments:
|
||||
if self.validation_file is not None:
|
||||
extension = self.validation_file.split(".")[-1]
|
||||
assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
|
||||
if self.test_file is not None:
|
||||
extension = self.test_file.split(".")[-1]
|
||||
assert extension in ["csv", "json"], "`test_file` should be a csv or a json file."
|
||||
|
||||
|
||||
def main():
|
||||
@@ -247,7 +266,9 @@ def main():
|
||||
if data_args.validation_file is not None:
|
||||
data_files["validation"] = data_args.validation_file
|
||||
extension = data_args.validation_file.split(".")[-1]
|
||||
|
||||
if data_args.test_file is not None:
|
||||
data_files["test"] = data_args.test_file
|
||||
extension = data_args.test_file.split(".")[-1]
|
||||
datasets = load_dataset(extension, data_files=data_files, field="data")
|
||||
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
|
||||
# https://huggingface.co/docs/datasets/loading_datasets.html.
|
||||
@@ -291,8 +312,10 @@ def main():
|
||||
# Preprocessing is slighlty different for training and evaluation.
|
||||
if training_args.do_train:
|
||||
column_names = datasets["train"].column_names
|
||||
else:
|
||||
elif training_args.do_eval:
|
||||
column_names = datasets["validation"].column_names
|
||||
else:
|
||||
column_names = datasets["test"].column_names
|
||||
question_column_name = "question" if "question" in column_names else column_names[0]
|
||||
context_column_name = "context" if "context" in column_names else column_names[1]
|
||||
answer_column_name = "answers" if "answers" in column_names else column_names[2]
|
||||
@@ -444,12 +467,12 @@ def main():
|
||||
if training_args.do_eval:
|
||||
if "validation" not in datasets:
|
||||
raise ValueError("--do_eval requires a validation dataset")
|
||||
eval_dataset = datasets["validation"]
|
||||
eval_examples = datasets["validation"]
|
||||
if data_args.max_val_samples is not None:
|
||||
# We will select sample from whole data
|
||||
eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
|
||||
eval_examples = eval_examples.select(range(data_args.max_val_samples))
|
||||
# Validation Feature Creation
|
||||
eval_dataset = eval_dataset.map(
|
||||
eval_dataset = eval_examples.map(
|
||||
prepare_validation_features,
|
||||
batched=True,
|
||||
num_proc=data_args.preprocessing_num_workers,
|
||||
@@ -460,6 +483,25 @@ def main():
|
||||
# During Feature creation dataset samples might increase, we will select required samples again
|
||||
eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
|
||||
|
||||
if training_args.do_predict:
|
||||
if "test" not in datasets:
|
||||
raise ValueError("--do_predict requires a test dataset")
|
||||
test_examples = datasets["test"]
|
||||
if data_args.max_test_samples is not None:
|
||||
# We will select sample from whole data
|
||||
test_examples = test_examples.select(range(data_args.max_test_samples))
|
||||
# Test Feature Creation
|
||||
test_dataset = test_examples.map(
|
||||
prepare_validation_features,
|
||||
batched=True,
|
||||
num_proc=data_args.preprocessing_num_workers,
|
||||
remove_columns=column_names,
|
||||
load_from_cache_file=not data_args.overwrite_cache,
|
||||
)
|
||||
if data_args.max_test_samples is not None:
|
||||
# During Feature creation dataset samples might increase, we will select required samples again
|
||||
test_dataset = test_dataset.select(range(data_args.max_test_samples))
|
||||
|
||||
# Data collator
|
||||
# We have already padded to max length if the corresponding flag is True, otherwise we need to pad in the data
|
||||
# collator.
|
||||
@@ -470,7 +512,7 @@ def main():
|
||||
)
|
||||
|
||||
# Post-processing:
|
||||
def post_processing_function(examples, features, predictions):
|
||||
def post_processing_function(examples, features, predictions, stage="eval"):
|
||||
# Post-processing: we match the start logits and end logits to answers in the original context.
|
||||
predictions = postprocess_qa_predictions(
|
||||
examples=examples,
|
||||
@@ -482,6 +524,7 @@ def main():
|
||||
null_score_diff_threshold=data_args.null_score_diff_threshold,
|
||||
output_dir=training_args.output_dir,
|
||||
is_world_process_zero=trainer.is_world_process_zero(),
|
||||
prefix=stage,
|
||||
)
|
||||
# Format the result to the format the metric expects.
|
||||
if data_args.version_2_with_negative:
|
||||
@@ -490,7 +533,8 @@ def main():
|
||||
]
|
||||
else:
|
||||
formatted_predictions = [{"id": k, "prediction_text": v} for k, v in predictions.items()]
|
||||
references = [{"id": ex["id"], "answers": ex[answer_column_name]} for ex in datasets["validation"]]
|
||||
|
||||
references = [{"id": ex["id"], "answers": ex[answer_column_name]} for ex in examples]
|
||||
return EvalPrediction(predictions=formatted_predictions, label_ids=references)
|
||||
|
||||
metric = load_metric("squad_v2" if data_args.version_2_with_negative else "squad")
|
||||
@@ -504,7 +548,7 @@ def main():
|
||||
args=training_args,
|
||||
train_dataset=train_dataset if training_args.do_train else None,
|
||||
eval_dataset=eval_dataset if training_args.do_eval else None,
|
||||
eval_examples=datasets["validation"] if training_args.do_eval else None,
|
||||
eval_examples=eval_examples if training_args.do_eval else None,
|
||||
tokenizer=tokenizer,
|
||||
data_collator=data_collator,
|
||||
post_process_function=post_processing_function,
|
||||
@@ -543,6 +587,18 @@ def main():
|
||||
trainer.log_metrics("eval", metrics)
|
||||
trainer.save_metrics("eval", metrics)
|
||||
|
||||
# Prediction
|
||||
if training_args.do_predict:
|
||||
logger.info("*** Predict ***")
|
||||
results = trainer.predict(test_dataset, test_examples)
|
||||
metrics = results.metrics
|
||||
|
||||
max_test_samples = data_args.max_test_samples if data_args.max_test_samples is not None else len(test_dataset)
|
||||
metrics["test_samples"] = min(max_test_samples, len(test_dataset))
|
||||
|
||||
trainer.log_metrics("test", metrics)
|
||||
trainer.save_metrics("test", metrics)
|
||||
|
||||
|
||||
def _mp_fn(index):
|
||||
# For xla_spawn (TPUs)
|
||||
|
||||
@@ -99,6 +99,10 @@ class DataTrainingArguments:
|
||||
default=None,
|
||||
metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
|
||||
)
|
||||
test_file: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "An optional input test data file to test the perplexity on (a text file)."},
|
||||
)
|
||||
overwrite_cache: bool = field(
|
||||
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
|
||||
)
|
||||
@@ -135,6 +139,13 @@ class DataTrainingArguments:
|
||||
"value if set."
|
||||
},
|
||||
)
|
||||
max_test_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "For debugging purposes or quicker training, truncate the number of test examples to this "
|
||||
"value if set."
|
||||
},
|
||||
)
|
||||
version_2_with_negative: bool = field(
|
||||
default=False, metadata={"help": "If true, some of the examples do not have an answer."}
|
||||
)
|
||||
@@ -163,8 +174,13 @@ class DataTrainingArguments:
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.dataset_name is None and self.train_file is None and self.validation_file is None:
|
||||
raise ValueError("Need either a dataset name or a training/validation file.")
|
||||
if (
|
||||
self.dataset_name is None
|
||||
and self.train_file is None
|
||||
and self.validation_file is None
|
||||
and self.test_file is None
|
||||
):
|
||||
raise ValueError("Need either a dataset name or a training/validation/test file.")
|
||||
else:
|
||||
if self.train_file is not None:
|
||||
extension = self.train_file.split(".")[-1]
|
||||
@@ -172,6 +188,9 @@ class DataTrainingArguments:
|
||||
if self.validation_file is not None:
|
||||
extension = self.validation_file.split(".")[-1]
|
||||
assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
|
||||
if self.test_file is not None:
|
||||
extension = self.test_file.split(".")[-1]
|
||||
assert extension in ["csv", "json"], "`test_file` should be a csv or a json file."
|
||||
|
||||
|
||||
def main():
|
||||
@@ -241,9 +260,13 @@ def main():
|
||||
data_files = {}
|
||||
if data_args.train_file is not None:
|
||||
data_files["train"] = data_args.train_file
|
||||
extension = data_args.train_file.split(".")[-1]
|
||||
if data_args.validation_file is not None:
|
||||
data_files["validation"] = data_args.validation_file
|
||||
extension = data_args.train_file.split(".")[-1]
|
||||
extension = data_args.validation_file.split(".")[-1]
|
||||
if data_args.test_file is not None:
|
||||
data_files["test"] = data_args.test_file
|
||||
extension = data_args.test_file.split(".")[-1]
|
||||
datasets = load_dataset(extension, data_files=data_files, field="data")
|
||||
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
|
||||
# https://huggingface.co/docs/datasets/loading_datasets.html.
|
||||
@@ -278,8 +301,10 @@ def main():
|
||||
# Preprocessing is slighlty different for training and evaluation.
|
||||
if training_args.do_train:
|
||||
column_names = datasets["train"].column_names
|
||||
else:
|
||||
elif training_args.do_eval:
|
||||
column_names = datasets["validation"].column_names
|
||||
else:
|
||||
column_names = datasets["test"].column_names
|
||||
question_column_name = "question" if "question" in column_names else column_names[0]
|
||||
context_column_name = "context" if "context" in column_names else column_names[1]
|
||||
answer_column_name = "answers" if "answers" in column_names else column_names[2]
|
||||
@@ -478,12 +503,12 @@ def main():
|
||||
if training_args.do_eval:
|
||||
if "validation" not in datasets:
|
||||
raise ValueError("--do_eval requires a validation dataset")
|
||||
eval_dataset = datasets["validation"]
|
||||
eval_examples = datasets["validation"]
|
||||
if data_args.max_val_samples is not None:
|
||||
# Selecting Eval Samples from Dataset
|
||||
eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
|
||||
eval_examples = eval_examples.select(range(data_args.max_val_samples))
|
||||
# Create Features from Eval Dataset
|
||||
eval_dataset = eval_dataset.map(
|
||||
eval_dataset = eval_examples.map(
|
||||
prepare_validation_features,
|
||||
batched=True,
|
||||
num_proc=data_args.preprocessing_num_workers,
|
||||
@@ -494,6 +519,25 @@ def main():
|
||||
# Selecting Samples from Dataset again since Feature Creation might increase samples size
|
||||
eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
|
||||
|
||||
if training_args.do_predict:
|
||||
if "test" not in datasets:
|
||||
raise ValueError("--do_predict requires a test dataset")
|
||||
test_examples = datasets["test"]
|
||||
if data_args.max_test_samples is not None:
|
||||
# We will select sample from whole data
|
||||
test_examples = test_examples.select(range(data_args.max_test_samples))
|
||||
# Test Feature Creation
|
||||
test_dataset = test_examples.map(
|
||||
prepare_validation_features,
|
||||
batched=True,
|
||||
num_proc=data_args.preprocessing_num_workers,
|
||||
remove_columns=column_names,
|
||||
load_from_cache_file=not data_args.overwrite_cache,
|
||||
)
|
||||
if data_args.max_test_samples is not None:
|
||||
# During Feature creation dataset samples might increase, we will select required samples again
|
||||
test_dataset = test_dataset.select(range(data_args.max_test_samples))
|
||||
|
||||
# Data collator
|
||||
# We have already padded to max length if the corresponding flag is True, otherwise we need to pad in the data
|
||||
# collator.
|
||||
@@ -504,7 +548,7 @@ def main():
|
||||
)
|
||||
|
||||
# Post-processing:
|
||||
def post_processing_function(examples, features, predictions):
|
||||
def post_processing_function(examples, features, predictions, stage="eval"):
|
||||
# Post-processing: we match the start logits and end logits to answers in the original context.
|
||||
predictions, scores_diff_json = postprocess_qa_predictions_with_beam_search(
|
||||
examples=examples,
|
||||
@@ -517,6 +561,7 @@ def main():
|
||||
end_n_top=model.config.end_n_top,
|
||||
output_dir=training_args.output_dir,
|
||||
is_world_process_zero=trainer.is_world_process_zero(),
|
||||
prefix=stage,
|
||||
)
|
||||
# Format the result to the format the metric expects.
|
||||
if data_args.version_2_with_negative:
|
||||
@@ -526,7 +571,8 @@ def main():
|
||||
]
|
||||
else:
|
||||
formatted_predictions = [{"id": k, "prediction_text": v} for k, v in predictions.items()]
|
||||
references = [{"id": ex["id"], "answers": ex[answer_column_name]} for ex in datasets["validation"]]
|
||||
|
||||
references = [{"id": ex["id"], "answers": ex[answer_column_name]} for ex in examples]
|
||||
return EvalPrediction(predictions=formatted_predictions, label_ids=references)
|
||||
|
||||
metric = load_metric("squad_v2" if data_args.version_2_with_negative else "squad")
|
||||
@@ -540,7 +586,7 @@ def main():
|
||||
args=training_args,
|
||||
train_dataset=train_dataset if training_args.do_train else None,
|
||||
eval_dataset=eval_dataset if training_args.do_eval else None,
|
||||
eval_examples=datasets["validation"] if training_args.do_eval else None,
|
||||
eval_examples=eval_examples if training_args.do_eval else None,
|
||||
tokenizer=tokenizer,
|
||||
data_collator=data_collator,
|
||||
post_process_function=post_processing_function,
|
||||
@@ -580,6 +626,18 @@ def main():
|
||||
trainer.log_metrics("eval", metrics)
|
||||
trainer.save_metrics("eval", metrics)
|
||||
|
||||
# Prediction
|
||||
if training_args.do_predict:
|
||||
logger.info("*** Predict ***")
|
||||
results = trainer.predict(test_dataset, test_examples)
|
||||
metrics = results.metrics
|
||||
|
||||
max_test_samples = data_args.max_test_samples if data_args.max_test_samples is not None else len(test_dataset)
|
||||
metrics["test_samples"] = min(max_test_samples, len(test_dataset))
|
||||
|
||||
trainer.log_metrics("test", metrics)
|
||||
trainer.save_metrics("test", metrics)
|
||||
|
||||
|
||||
def _mp_fn(index):
|
||||
# For xla_spawn (TPUs)
|
||||
|
||||
@@ -98,7 +98,7 @@ class QuestionAnsweringTrainer(Trainer):
|
||||
if isinstance(test_dataset, datasets.Dataset):
|
||||
test_dataset.set_format(type=test_dataset.format["type"], columns=list(test_dataset.features.keys()))
|
||||
|
||||
eval_preds = self.post_process_function(test_examples, test_dataset, output.predictions)
|
||||
eval_preds = self.post_process_function(test_examples, test_dataset, output.predictions, "test")
|
||||
metrics = self.compute_metrics(eval_preds)
|
||||
|
||||
return PredictionOutput(predictions=eval_preds.predictions, label_ids=eval_preds.label_ids, metrics=metrics)
|
||||
|
||||
@@ -215,14 +215,14 @@ def postprocess_qa_predictions(
|
||||
assert os.path.isdir(output_dir), f"{output_dir} is not a directory."
|
||||
|
||||
prediction_file = os.path.join(
|
||||
output_dir, "predictions.json" if prefix is None else f"predictions_{prefix}".json
|
||||
output_dir, "predictions.json" if prefix is None else f"{prefix}_predictions.json"
|
||||
)
|
||||
nbest_file = os.path.join(
|
||||
output_dir, "nbest_predictions.json" if prefix is None else f"nbest_predictions_{prefix}".json
|
||||
output_dir, "nbest_predictions.json" if prefix is None else f"{prefix}_nbest_predictions.json"
|
||||
)
|
||||
if version_2_with_negative:
|
||||
null_odds_file = os.path.join(
|
||||
output_dir, "null_odds.json" if prefix is None else f"null_odds_{prefix}".json
|
||||
output_dir, "null_odds.json" if prefix is None else f"{prefix}_null_odds_{prefix}.json"
|
||||
)
|
||||
|
||||
logger.info(f"Saving predictions to {prediction_file}.")
|
||||
@@ -403,14 +403,14 @@ def postprocess_qa_predictions_with_beam_search(
|
||||
assert os.path.isdir(output_dir), f"{output_dir} is not a directory."
|
||||
|
||||
prediction_file = os.path.join(
|
||||
output_dir, "predictions.json" if prefix is None else f"predictions_{prefix}".json
|
||||
output_dir, "predictions.json" if prefix is None else f"{prefix}_predictions.json"
|
||||
)
|
||||
nbest_file = os.path.join(
|
||||
output_dir, "nbest_predictions.json" if prefix is None else f"nbest_predictions_{prefix}".json
|
||||
output_dir, "nbest_predictions.json" if prefix is None else f"{prefix}_nbest_predictions.json"
|
||||
)
|
||||
if version_2_with_negative:
|
||||
null_odds_file = os.path.join(
|
||||
output_dir, "null_odds.json" if prefix is None else f"null_odds_{prefix}".json
|
||||
output_dir, "null_odds.json" if prefix is None else f"{prefix}_null_odds.json"
|
||||
)
|
||||
|
||||
print(f"Saving predictions to {prediction_file}.")
|
||||
|
||||
Reference in New Issue
Block a user