[examples/s2s] add test set predictions (#10085)
* add do_predict, pass eval_beams durig eval * update help * apply suggestions from code review
This commit is contained in:
@@ -167,9 +167,22 @@ class DataTrainingArguments:
|
|||||||
"value if set."
|
"value if set."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
max_test_samples: Optional[int] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={
|
||||||
|
"help": "For debugging purposes or quicker training, truncate the number of test examples to this "
|
||||||
|
"value if set."
|
||||||
|
},
|
||||||
|
)
|
||||||
source_lang: Optional[str] = field(default=None, metadata={"help": "Source language id for translation."})
|
source_lang: Optional[str] = field(default=None, metadata={"help": "Source language id for translation."})
|
||||||
target_lang: Optional[str] = field(default=None, metadata={"help": "Target language id for translation."})
|
target_lang: Optional[str] = field(default=None, metadata={"help": "Target language id for translation."})
|
||||||
eval_beams: Optional[int] = field(default=None, metadata={"help": "Number of beams to use for evaluation."})
|
num_beams: Optional[int] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={
|
||||||
|
"help": "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, "
|
||||||
|
"which is used during ``evaluate`` and ``predict``."
|
||||||
|
},
|
||||||
|
)
|
||||||
ignore_pad_token_for_loss: bool = field(
|
ignore_pad_token_for_loss: bool = field(
|
||||||
default=True,
|
default=True,
|
||||||
metadata={
|
metadata={
|
||||||
@@ -336,8 +349,13 @@ def main():
|
|||||||
# We need to tokenize inputs and targets.
|
# We need to tokenize inputs and targets.
|
||||||
if training_args.do_train:
|
if training_args.do_train:
|
||||||
column_names = datasets["train"].column_names
|
column_names = datasets["train"].column_names
|
||||||
else:
|
elif training_args.do_eval:
|
||||||
column_names = datasets["validation"].column_names
|
column_names = datasets["validation"].column_names
|
||||||
|
elif training_args.do_predict:
|
||||||
|
column_names = datasets["test"].column_names
|
||||||
|
else:
|
||||||
|
logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
|
||||||
|
return
|
||||||
|
|
||||||
# For translation we set the codes of our source and target languages (only useful for mBART, the others will
|
# For translation we set the codes of our source and target languages (only useful for mBART, the others will
|
||||||
# ignore those attributes).
|
# ignore those attributes).
|
||||||
@@ -440,6 +458,19 @@ def main():
|
|||||||
load_from_cache_file=not data_args.overwrite_cache,
|
load_from_cache_file=not data_args.overwrite_cache,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if training_args.do_predict:
|
||||||
|
max_target_length = data_args.val_max_target_length
|
||||||
|
test_dataset = datasets["test"]
|
||||||
|
if data_args.max_test_samples is not None:
|
||||||
|
test_dataset = test_dataset.select(range(data_args.max_test_samples))
|
||||||
|
test_dataset = test_dataset.map(
|
||||||
|
preprocess_function,
|
||||||
|
batched=True,
|
||||||
|
num_proc=data_args.preprocessing_num_workers,
|
||||||
|
remove_columns=column_names,
|
||||||
|
load_from_cache_file=not data_args.overwrite_cache,
|
||||||
|
)
|
||||||
|
|
||||||
# Data collator
|
# Data collator
|
||||||
label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
|
label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
|
||||||
if data_args.pad_to_max_length:
|
if data_args.pad_to_max_length:
|
||||||
@@ -523,7 +554,7 @@ def main():
|
|||||||
if training_args.do_eval:
|
if training_args.do_eval:
|
||||||
logger.info("*** Evaluate ***")
|
logger.info("*** Evaluate ***")
|
||||||
|
|
||||||
results = trainer.evaluate()
|
results = trainer.evaluate(max_length=data_args.val_max_target_length, num_beams=data_args.num_beams)
|
||||||
|
|
||||||
output_eval_file = os.path.join(training_args.output_dir, "eval_results_seq2seq.txt")
|
output_eval_file = os.path.join(training_args.output_dir, "eval_results_seq2seq.txt")
|
||||||
if trainer.is_world_process_zero():
|
if trainer.is_world_process_zero():
|
||||||
@@ -533,6 +564,34 @@ def main():
|
|||||||
logger.info(f" {key} = {value}")
|
logger.info(f" {key} = {value}")
|
||||||
writer.write(f"{key} = {value}\n")
|
writer.write(f"{key} = {value}\n")
|
||||||
|
|
||||||
|
if training_args.do_predict:
|
||||||
|
logger.info("*** Test ***")
|
||||||
|
|
||||||
|
test_results = trainer.predict(
|
||||||
|
test_dataset,
|
||||||
|
metric_key_prefix="test",
|
||||||
|
max_length=data_args.val_max_target_length,
|
||||||
|
num_beams=data_args.num_beams,
|
||||||
|
)
|
||||||
|
test_metrics = test_results.metrics
|
||||||
|
|
||||||
|
output_test_result_file = os.path.join(training_args.output_dir, "test_results_seq2seq.txt")
|
||||||
|
if trainer.is_world_process_zero():
|
||||||
|
with open(output_test_result_file, "w") as writer:
|
||||||
|
logger.info("***** Test results *****")
|
||||||
|
for key, value in sorted(test_metrics.items()):
|
||||||
|
logger.info(f" {key} = {value}")
|
||||||
|
writer.write(f"{key} = {value}\n")
|
||||||
|
|
||||||
|
if training_args.predict_with_generate:
|
||||||
|
test_preds = tokenizer.batch_decode(
|
||||||
|
test_results.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True
|
||||||
|
)
|
||||||
|
test_preds = [pred.strip() for pred in test_preds]
|
||||||
|
output_test_preds_file = os.path.join(training_args.output_dir, "test_preds_seq2seq.txt")
|
||||||
|
with open(output_test_preds_file, "w") as writer:
|
||||||
|
writer.write("\n".join(test_preds))
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user