[Example] Updating Question Answering examples for Predict Stage (#10792)
* added prediction stage and eval fix * style correction * removed extra lines
This commit is contained in:
@@ -100,6 +100,10 @@ class DataTrainingArguments:
|
|||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
|
metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
|
||||||
)
|
)
|
||||||
|
test_file: Optional[str] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "An optional input test data file to evaluate the perplexity on (a text file)."},
|
||||||
|
)
|
||||||
overwrite_cache: bool = field(
|
overwrite_cache: bool = field(
|
||||||
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
|
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
|
||||||
)
|
)
|
||||||
@@ -136,6 +140,13 @@ class DataTrainingArguments:
|
|||||||
"value if set."
|
"value if set."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
max_test_samples: Optional[int] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={
|
||||||
|
"help": "For debugging purposes or quicker training, truncate the number of test examples to this "
|
||||||
|
"value if set."
|
||||||
|
},
|
||||||
|
)
|
||||||
version_2_with_negative: bool = field(
|
version_2_with_negative: bool = field(
|
||||||
default=False, metadata={"help": "If true, some of the examples do not have an answer."}
|
default=False, metadata={"help": "If true, some of the examples do not have an answer."}
|
||||||
)
|
)
|
||||||
@@ -164,8 +175,13 @@ class DataTrainingArguments:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
if self.dataset_name is None and self.train_file is None and self.validation_file is None:
|
if (
|
||||||
raise ValueError("Need either a dataset name or a training/validation file.")
|
self.dataset_name is None
|
||||||
|
and self.train_file is None
|
||||||
|
and self.validation_file is None
|
||||||
|
and self.test_file is None
|
||||||
|
):
|
||||||
|
raise ValueError("Need either a dataset name or a training/validation file/test_file.")
|
||||||
else:
|
else:
|
||||||
if self.train_file is not None:
|
if self.train_file is not None:
|
||||||
extension = self.train_file.split(".")[-1]
|
extension = self.train_file.split(".")[-1]
|
||||||
@@ -173,6 +189,9 @@ class DataTrainingArguments:
|
|||||||
if self.validation_file is not None:
|
if self.validation_file is not None:
|
||||||
extension = self.validation_file.split(".")[-1]
|
extension = self.validation_file.split(".")[-1]
|
||||||
assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
|
assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
|
||||||
|
if self.test_file is not None:
|
||||||
|
extension = self.test_file.split(".")[-1]
|
||||||
|
assert extension in ["csv", "json"], "`test_file` should be a csv or a json file."
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
@@ -247,7 +266,9 @@ def main():
|
|||||||
if data_args.validation_file is not None:
|
if data_args.validation_file is not None:
|
||||||
data_files["validation"] = data_args.validation_file
|
data_files["validation"] = data_args.validation_file
|
||||||
extension = data_args.validation_file.split(".")[-1]
|
extension = data_args.validation_file.split(".")[-1]
|
||||||
|
if data_args.test_file is not None:
|
||||||
|
data_files["test"] = data_args.test_file
|
||||||
|
extension = data_args.test_file.split(".")[-1]
|
||||||
datasets = load_dataset(extension, data_files=data_files, field="data")
|
datasets = load_dataset(extension, data_files=data_files, field="data")
|
||||||
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
|
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
|
||||||
# https://huggingface.co/docs/datasets/loading_datasets.html.
|
# https://huggingface.co/docs/datasets/loading_datasets.html.
|
||||||
@@ -291,8 +312,10 @@ def main():
|
|||||||
# Preprocessing is slighlty different for training and evaluation.
|
# Preprocessing is slighlty different for training and evaluation.
|
||||||
if training_args.do_train:
|
if training_args.do_train:
|
||||||
column_names = datasets["train"].column_names
|
column_names = datasets["train"].column_names
|
||||||
else:
|
elif training_args.do_eval:
|
||||||
column_names = datasets["validation"].column_names
|
column_names = datasets["validation"].column_names
|
||||||
|
else:
|
||||||
|
column_names = datasets["test"].column_names
|
||||||
question_column_name = "question" if "question" in column_names else column_names[0]
|
question_column_name = "question" if "question" in column_names else column_names[0]
|
||||||
context_column_name = "context" if "context" in column_names else column_names[1]
|
context_column_name = "context" if "context" in column_names else column_names[1]
|
||||||
answer_column_name = "answers" if "answers" in column_names else column_names[2]
|
answer_column_name = "answers" if "answers" in column_names else column_names[2]
|
||||||
@@ -444,12 +467,12 @@ def main():
|
|||||||
if training_args.do_eval:
|
if training_args.do_eval:
|
||||||
if "validation" not in datasets:
|
if "validation" not in datasets:
|
||||||
raise ValueError("--do_eval requires a validation dataset")
|
raise ValueError("--do_eval requires a validation dataset")
|
||||||
eval_dataset = datasets["validation"]
|
eval_examples = datasets["validation"]
|
||||||
if data_args.max_val_samples is not None:
|
if data_args.max_val_samples is not None:
|
||||||
# We will select sample from whole data
|
# We will select sample from whole data
|
||||||
eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
|
eval_examples = eval_examples.select(range(data_args.max_val_samples))
|
||||||
# Validation Feature Creation
|
# Validation Feature Creation
|
||||||
eval_dataset = eval_dataset.map(
|
eval_dataset = eval_examples.map(
|
||||||
prepare_validation_features,
|
prepare_validation_features,
|
||||||
batched=True,
|
batched=True,
|
||||||
num_proc=data_args.preprocessing_num_workers,
|
num_proc=data_args.preprocessing_num_workers,
|
||||||
@@ -460,6 +483,25 @@ def main():
|
|||||||
# During Feature creation dataset samples might increase, we will select required samples again
|
# During Feature creation dataset samples might increase, we will select required samples again
|
||||||
eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
|
eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
|
||||||
|
|
||||||
|
if training_args.do_predict:
|
||||||
|
if "test" not in datasets:
|
||||||
|
raise ValueError("--do_predict requires a test dataset")
|
||||||
|
test_examples = datasets["test"]
|
||||||
|
if data_args.max_test_samples is not None:
|
||||||
|
# We will select sample from whole data
|
||||||
|
test_examples = test_examples.select(range(data_args.max_test_samples))
|
||||||
|
# Test Feature Creation
|
||||||
|
test_dataset = test_examples.map(
|
||||||
|
prepare_validation_features,
|
||||||
|
batched=True,
|
||||||
|
num_proc=data_args.preprocessing_num_workers,
|
||||||
|
remove_columns=column_names,
|
||||||
|
load_from_cache_file=not data_args.overwrite_cache,
|
||||||
|
)
|
||||||
|
if data_args.max_test_samples is not None:
|
||||||
|
# During Feature creation dataset samples might increase, we will select required samples again
|
||||||
|
test_dataset = test_dataset.select(range(data_args.max_test_samples))
|
||||||
|
|
||||||
# Data collator
|
# Data collator
|
||||||
# We have already padded to max length if the corresponding flag is True, otherwise we need to pad in the data
|
# We have already padded to max length if the corresponding flag is True, otherwise we need to pad in the data
|
||||||
# collator.
|
# collator.
|
||||||
@@ -470,7 +512,7 @@ def main():
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Post-processing:
|
# Post-processing:
|
||||||
def post_processing_function(examples, features, predictions):
|
def post_processing_function(examples, features, predictions, stage="eval"):
|
||||||
# Post-processing: we match the start logits and end logits to answers in the original context.
|
# Post-processing: we match the start logits and end logits to answers in the original context.
|
||||||
predictions = postprocess_qa_predictions(
|
predictions = postprocess_qa_predictions(
|
||||||
examples=examples,
|
examples=examples,
|
||||||
@@ -482,6 +524,7 @@ def main():
|
|||||||
null_score_diff_threshold=data_args.null_score_diff_threshold,
|
null_score_diff_threshold=data_args.null_score_diff_threshold,
|
||||||
output_dir=training_args.output_dir,
|
output_dir=training_args.output_dir,
|
||||||
is_world_process_zero=trainer.is_world_process_zero(),
|
is_world_process_zero=trainer.is_world_process_zero(),
|
||||||
|
prefix=stage,
|
||||||
)
|
)
|
||||||
# Format the result to the format the metric expects.
|
# Format the result to the format the metric expects.
|
||||||
if data_args.version_2_with_negative:
|
if data_args.version_2_with_negative:
|
||||||
@@ -490,7 +533,8 @@ def main():
|
|||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
formatted_predictions = [{"id": k, "prediction_text": v} for k, v in predictions.items()]
|
formatted_predictions = [{"id": k, "prediction_text": v} for k, v in predictions.items()]
|
||||||
references = [{"id": ex["id"], "answers": ex[answer_column_name]} for ex in datasets["validation"]]
|
|
||||||
|
references = [{"id": ex["id"], "answers": ex[answer_column_name]} for ex in examples]
|
||||||
return EvalPrediction(predictions=formatted_predictions, label_ids=references)
|
return EvalPrediction(predictions=formatted_predictions, label_ids=references)
|
||||||
|
|
||||||
metric = load_metric("squad_v2" if data_args.version_2_with_negative else "squad")
|
metric = load_metric("squad_v2" if data_args.version_2_with_negative else "squad")
|
||||||
@@ -504,7 +548,7 @@ def main():
|
|||||||
args=training_args,
|
args=training_args,
|
||||||
train_dataset=train_dataset if training_args.do_train else None,
|
train_dataset=train_dataset if training_args.do_train else None,
|
||||||
eval_dataset=eval_dataset if training_args.do_eval else None,
|
eval_dataset=eval_dataset if training_args.do_eval else None,
|
||||||
eval_examples=datasets["validation"] if training_args.do_eval else None,
|
eval_examples=eval_examples if training_args.do_eval else None,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
data_collator=data_collator,
|
data_collator=data_collator,
|
||||||
post_process_function=post_processing_function,
|
post_process_function=post_processing_function,
|
||||||
@@ -543,6 +587,18 @@ def main():
|
|||||||
trainer.log_metrics("eval", metrics)
|
trainer.log_metrics("eval", metrics)
|
||||||
trainer.save_metrics("eval", metrics)
|
trainer.save_metrics("eval", metrics)
|
||||||
|
|
||||||
|
# Prediction
|
||||||
|
if training_args.do_predict:
|
||||||
|
logger.info("*** Predict ***")
|
||||||
|
results = trainer.predict(test_dataset, test_examples)
|
||||||
|
metrics = results.metrics
|
||||||
|
|
||||||
|
max_test_samples = data_args.max_test_samples if data_args.max_test_samples is not None else len(test_dataset)
|
||||||
|
metrics["test_samples"] = min(max_test_samples, len(test_dataset))
|
||||||
|
|
||||||
|
trainer.log_metrics("test", metrics)
|
||||||
|
trainer.save_metrics("test", metrics)
|
||||||
|
|
||||||
|
|
||||||
def _mp_fn(index):
|
def _mp_fn(index):
|
||||||
# For xla_spawn (TPUs)
|
# For xla_spawn (TPUs)
|
||||||
|
|||||||
@@ -99,6 +99,10 @@ class DataTrainingArguments:
|
|||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
|
metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
|
||||||
)
|
)
|
||||||
|
test_file: Optional[str] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "An optional input test data file to test the perplexity on (a text file)."},
|
||||||
|
)
|
||||||
overwrite_cache: bool = field(
|
overwrite_cache: bool = field(
|
||||||
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
|
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
|
||||||
)
|
)
|
||||||
@@ -135,6 +139,13 @@ class DataTrainingArguments:
|
|||||||
"value if set."
|
"value if set."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
max_test_samples: Optional[int] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={
|
||||||
|
"help": "For debugging purposes or quicker training, truncate the number of test examples to this "
|
||||||
|
"value if set."
|
||||||
|
},
|
||||||
|
)
|
||||||
version_2_with_negative: bool = field(
|
version_2_with_negative: bool = field(
|
||||||
default=False, metadata={"help": "If true, some of the examples do not have an answer."}
|
default=False, metadata={"help": "If true, some of the examples do not have an answer."}
|
||||||
)
|
)
|
||||||
@@ -163,8 +174,13 @@ class DataTrainingArguments:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
if self.dataset_name is None and self.train_file is None and self.validation_file is None:
|
if (
|
||||||
raise ValueError("Need either a dataset name or a training/validation file.")
|
self.dataset_name is None
|
||||||
|
and self.train_file is None
|
||||||
|
and self.validation_file is None
|
||||||
|
and self.test_file is None
|
||||||
|
):
|
||||||
|
raise ValueError("Need either a dataset name or a training/validation/test file.")
|
||||||
else:
|
else:
|
||||||
if self.train_file is not None:
|
if self.train_file is not None:
|
||||||
extension = self.train_file.split(".")[-1]
|
extension = self.train_file.split(".")[-1]
|
||||||
@@ -172,6 +188,9 @@ class DataTrainingArguments:
|
|||||||
if self.validation_file is not None:
|
if self.validation_file is not None:
|
||||||
extension = self.validation_file.split(".")[-1]
|
extension = self.validation_file.split(".")[-1]
|
||||||
assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
|
assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
|
||||||
|
if self.test_file is not None:
|
||||||
|
extension = self.test_file.split(".")[-1]
|
||||||
|
assert extension in ["csv", "json"], "`test_file` should be a csv or a json file."
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
@@ -241,9 +260,13 @@ def main():
|
|||||||
data_files = {}
|
data_files = {}
|
||||||
if data_args.train_file is not None:
|
if data_args.train_file is not None:
|
||||||
data_files["train"] = data_args.train_file
|
data_files["train"] = data_args.train_file
|
||||||
|
extension = data_args.train_file.split(".")[-1]
|
||||||
if data_args.validation_file is not None:
|
if data_args.validation_file is not None:
|
||||||
data_files["validation"] = data_args.validation_file
|
data_files["validation"] = data_args.validation_file
|
||||||
extension = data_args.train_file.split(".")[-1]
|
extension = data_args.validation_file.split(".")[-1]
|
||||||
|
if data_args.test_file is not None:
|
||||||
|
data_files["test"] = data_args.test_file
|
||||||
|
extension = data_args.test_file.split(".")[-1]
|
||||||
datasets = load_dataset(extension, data_files=data_files, field="data")
|
datasets = load_dataset(extension, data_files=data_files, field="data")
|
||||||
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
|
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
|
||||||
# https://huggingface.co/docs/datasets/loading_datasets.html.
|
# https://huggingface.co/docs/datasets/loading_datasets.html.
|
||||||
@@ -278,8 +301,10 @@ def main():
|
|||||||
# Preprocessing is slighlty different for training and evaluation.
|
# Preprocessing is slighlty different for training and evaluation.
|
||||||
if training_args.do_train:
|
if training_args.do_train:
|
||||||
column_names = datasets["train"].column_names
|
column_names = datasets["train"].column_names
|
||||||
else:
|
elif training_args.do_eval:
|
||||||
column_names = datasets["validation"].column_names
|
column_names = datasets["validation"].column_names
|
||||||
|
else:
|
||||||
|
column_names = datasets["test"].column_names
|
||||||
question_column_name = "question" if "question" in column_names else column_names[0]
|
question_column_name = "question" if "question" in column_names else column_names[0]
|
||||||
context_column_name = "context" if "context" in column_names else column_names[1]
|
context_column_name = "context" if "context" in column_names else column_names[1]
|
||||||
answer_column_name = "answers" if "answers" in column_names else column_names[2]
|
answer_column_name = "answers" if "answers" in column_names else column_names[2]
|
||||||
@@ -478,12 +503,12 @@ def main():
|
|||||||
if training_args.do_eval:
|
if training_args.do_eval:
|
||||||
if "validation" not in datasets:
|
if "validation" not in datasets:
|
||||||
raise ValueError("--do_eval requires a validation dataset")
|
raise ValueError("--do_eval requires a validation dataset")
|
||||||
eval_dataset = datasets["validation"]
|
eval_examples = datasets["validation"]
|
||||||
if data_args.max_val_samples is not None:
|
if data_args.max_val_samples is not None:
|
||||||
# Selecting Eval Samples from Dataset
|
# Selecting Eval Samples from Dataset
|
||||||
eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
|
eval_examples = eval_examples.select(range(data_args.max_val_samples))
|
||||||
# Create Features from Eval Dataset
|
# Create Features from Eval Dataset
|
||||||
eval_dataset = eval_dataset.map(
|
eval_dataset = eval_examples.map(
|
||||||
prepare_validation_features,
|
prepare_validation_features,
|
||||||
batched=True,
|
batched=True,
|
||||||
num_proc=data_args.preprocessing_num_workers,
|
num_proc=data_args.preprocessing_num_workers,
|
||||||
@@ -494,6 +519,25 @@ def main():
|
|||||||
# Selecting Samples from Dataset again since Feature Creation might increase samples size
|
# Selecting Samples from Dataset again since Feature Creation might increase samples size
|
||||||
eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
|
eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
|
||||||
|
|
||||||
|
if training_args.do_predict:
|
||||||
|
if "test" not in datasets:
|
||||||
|
raise ValueError("--do_predict requires a test dataset")
|
||||||
|
test_examples = datasets["test"]
|
||||||
|
if data_args.max_test_samples is not None:
|
||||||
|
# We will select sample from whole data
|
||||||
|
test_examples = test_examples.select(range(data_args.max_test_samples))
|
||||||
|
# Test Feature Creation
|
||||||
|
test_dataset = test_examples.map(
|
||||||
|
prepare_validation_features,
|
||||||
|
batched=True,
|
||||||
|
num_proc=data_args.preprocessing_num_workers,
|
||||||
|
remove_columns=column_names,
|
||||||
|
load_from_cache_file=not data_args.overwrite_cache,
|
||||||
|
)
|
||||||
|
if data_args.max_test_samples is not None:
|
||||||
|
# During Feature creation dataset samples might increase, we will select required samples again
|
||||||
|
test_dataset = test_dataset.select(range(data_args.max_test_samples))
|
||||||
|
|
||||||
# Data collator
|
# Data collator
|
||||||
# We have already padded to max length if the corresponding flag is True, otherwise we need to pad in the data
|
# We have already padded to max length if the corresponding flag is True, otherwise we need to pad in the data
|
||||||
# collator.
|
# collator.
|
||||||
@@ -504,7 +548,7 @@ def main():
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Post-processing:
|
# Post-processing:
|
||||||
def post_processing_function(examples, features, predictions):
|
def post_processing_function(examples, features, predictions, stage="eval"):
|
||||||
# Post-processing: we match the start logits and end logits to answers in the original context.
|
# Post-processing: we match the start logits and end logits to answers in the original context.
|
||||||
predictions, scores_diff_json = postprocess_qa_predictions_with_beam_search(
|
predictions, scores_diff_json = postprocess_qa_predictions_with_beam_search(
|
||||||
examples=examples,
|
examples=examples,
|
||||||
@@ -517,6 +561,7 @@ def main():
|
|||||||
end_n_top=model.config.end_n_top,
|
end_n_top=model.config.end_n_top,
|
||||||
output_dir=training_args.output_dir,
|
output_dir=training_args.output_dir,
|
||||||
is_world_process_zero=trainer.is_world_process_zero(),
|
is_world_process_zero=trainer.is_world_process_zero(),
|
||||||
|
prefix=stage,
|
||||||
)
|
)
|
||||||
# Format the result to the format the metric expects.
|
# Format the result to the format the metric expects.
|
||||||
if data_args.version_2_with_negative:
|
if data_args.version_2_with_negative:
|
||||||
@@ -526,7 +571,8 @@ def main():
|
|||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
formatted_predictions = [{"id": k, "prediction_text": v} for k, v in predictions.items()]
|
formatted_predictions = [{"id": k, "prediction_text": v} for k, v in predictions.items()]
|
||||||
references = [{"id": ex["id"], "answers": ex[answer_column_name]} for ex in datasets["validation"]]
|
|
||||||
|
references = [{"id": ex["id"], "answers": ex[answer_column_name]} for ex in examples]
|
||||||
return EvalPrediction(predictions=formatted_predictions, label_ids=references)
|
return EvalPrediction(predictions=formatted_predictions, label_ids=references)
|
||||||
|
|
||||||
metric = load_metric("squad_v2" if data_args.version_2_with_negative else "squad")
|
metric = load_metric("squad_v2" if data_args.version_2_with_negative else "squad")
|
||||||
@@ -540,7 +586,7 @@ def main():
|
|||||||
args=training_args,
|
args=training_args,
|
||||||
train_dataset=train_dataset if training_args.do_train else None,
|
train_dataset=train_dataset if training_args.do_train else None,
|
||||||
eval_dataset=eval_dataset if training_args.do_eval else None,
|
eval_dataset=eval_dataset if training_args.do_eval else None,
|
||||||
eval_examples=datasets["validation"] if training_args.do_eval else None,
|
eval_examples=eval_examples if training_args.do_eval else None,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
data_collator=data_collator,
|
data_collator=data_collator,
|
||||||
post_process_function=post_processing_function,
|
post_process_function=post_processing_function,
|
||||||
@@ -580,6 +626,18 @@ def main():
|
|||||||
trainer.log_metrics("eval", metrics)
|
trainer.log_metrics("eval", metrics)
|
||||||
trainer.save_metrics("eval", metrics)
|
trainer.save_metrics("eval", metrics)
|
||||||
|
|
||||||
|
# Prediction
|
||||||
|
if training_args.do_predict:
|
||||||
|
logger.info("*** Predict ***")
|
||||||
|
results = trainer.predict(test_dataset, test_examples)
|
||||||
|
metrics = results.metrics
|
||||||
|
|
||||||
|
max_test_samples = data_args.max_test_samples if data_args.max_test_samples is not None else len(test_dataset)
|
||||||
|
metrics["test_samples"] = min(max_test_samples, len(test_dataset))
|
||||||
|
|
||||||
|
trainer.log_metrics("test", metrics)
|
||||||
|
trainer.save_metrics("test", metrics)
|
||||||
|
|
||||||
|
|
||||||
def _mp_fn(index):
|
def _mp_fn(index):
|
||||||
# For xla_spawn (TPUs)
|
# For xla_spawn (TPUs)
|
||||||
|
|||||||
@@ -98,7 +98,7 @@ class QuestionAnsweringTrainer(Trainer):
|
|||||||
if isinstance(test_dataset, datasets.Dataset):
|
if isinstance(test_dataset, datasets.Dataset):
|
||||||
test_dataset.set_format(type=test_dataset.format["type"], columns=list(test_dataset.features.keys()))
|
test_dataset.set_format(type=test_dataset.format["type"], columns=list(test_dataset.features.keys()))
|
||||||
|
|
||||||
eval_preds = self.post_process_function(test_examples, test_dataset, output.predictions)
|
eval_preds = self.post_process_function(test_examples, test_dataset, output.predictions, "test")
|
||||||
metrics = self.compute_metrics(eval_preds)
|
metrics = self.compute_metrics(eval_preds)
|
||||||
|
|
||||||
return PredictionOutput(predictions=eval_preds.predictions, label_ids=eval_preds.label_ids, metrics=metrics)
|
return PredictionOutput(predictions=eval_preds.predictions, label_ids=eval_preds.label_ids, metrics=metrics)
|
||||||
|
|||||||
@@ -215,14 +215,14 @@ def postprocess_qa_predictions(
|
|||||||
assert os.path.isdir(output_dir), f"{output_dir} is not a directory."
|
assert os.path.isdir(output_dir), f"{output_dir} is not a directory."
|
||||||
|
|
||||||
prediction_file = os.path.join(
|
prediction_file = os.path.join(
|
||||||
output_dir, "predictions.json" if prefix is None else f"predictions_{prefix}".json
|
output_dir, "predictions.json" if prefix is None else f"{prefix}_predictions.json"
|
||||||
)
|
)
|
||||||
nbest_file = os.path.join(
|
nbest_file = os.path.join(
|
||||||
output_dir, "nbest_predictions.json" if prefix is None else f"nbest_predictions_{prefix}".json
|
output_dir, "nbest_predictions.json" if prefix is None else f"{prefix}_nbest_predictions.json"
|
||||||
)
|
)
|
||||||
if version_2_with_negative:
|
if version_2_with_negative:
|
||||||
null_odds_file = os.path.join(
|
null_odds_file = os.path.join(
|
||||||
output_dir, "null_odds.json" if prefix is None else f"null_odds_{prefix}".json
|
output_dir, "null_odds.json" if prefix is None else f"{prefix}_null_odds_{prefix}.json"
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"Saving predictions to {prediction_file}.")
|
logger.info(f"Saving predictions to {prediction_file}.")
|
||||||
@@ -403,14 +403,14 @@ def postprocess_qa_predictions_with_beam_search(
|
|||||||
assert os.path.isdir(output_dir), f"{output_dir} is not a directory."
|
assert os.path.isdir(output_dir), f"{output_dir} is not a directory."
|
||||||
|
|
||||||
prediction_file = os.path.join(
|
prediction_file = os.path.join(
|
||||||
output_dir, "predictions.json" if prefix is None else f"predictions_{prefix}".json
|
output_dir, "predictions.json" if prefix is None else f"{prefix}_predictions.json"
|
||||||
)
|
)
|
||||||
nbest_file = os.path.join(
|
nbest_file = os.path.join(
|
||||||
output_dir, "nbest_predictions.json" if prefix is None else f"nbest_predictions_{prefix}".json
|
output_dir, "nbest_predictions.json" if prefix is None else f"{prefix}_nbest_predictions.json"
|
||||||
)
|
)
|
||||||
if version_2_with_negative:
|
if version_2_with_negative:
|
||||||
null_odds_file = os.path.join(
|
null_odds_file = os.path.join(
|
||||||
output_dir, "null_odds.json" if prefix is None else f"null_odds_{prefix}".json
|
output_dir, "null_odds.json" if prefix is None else f"{prefix}_null_odds.json"
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f"Saving predictions to {prediction_file}.")
|
print(f"Saving predictions to {prediction_file}.")
|
||||||
|
|||||||
Reference in New Issue
Block a user