[Examples] Fixes inconsistency around eval vs val and predict vs test (#11380)
* added changes for uniformity * modified files * corrected typo * fixed qa scripts * fix typos * fixed predict typo in qa no trainer * fixed test file * reverted trainer changes * reverted trainer changes in custom examples * updated readme * added changes in deepspeed test * added changes for predict and eval
This commit is contained in:
@@ -178,17 +178,17 @@ class DataTrainingArguments:
|
||||
"value if set."
|
||||
},
|
||||
)
|
||||
max_val_samples: Optional[int] = field(
|
||||
max_eval_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
|
||||
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
||||
"value if set."
|
||||
},
|
||||
)
|
||||
max_test_samples: Optional[int] = field(
|
||||
max_predict_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "For debugging purposes or quicker training, truncate the number of test examples to this "
|
||||
"help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
|
||||
"value if set."
|
||||
},
|
||||
)
|
||||
@@ -438,8 +438,8 @@ def main():
|
||||
if "validation" not in datasets:
|
||||
raise ValueError("--do_eval requires a validation dataset")
|
||||
eval_dataset = datasets["validation"]
|
||||
if data_args.max_val_samples is not None:
|
||||
eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
|
||||
if data_args.max_eval_samples is not None:
|
||||
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
|
||||
eval_dataset = eval_dataset.map(
|
||||
preprocess_function,
|
||||
batched=True,
|
||||
@@ -452,10 +452,10 @@ def main():
|
||||
max_target_length = data_args.val_max_target_length
|
||||
if "test" not in datasets:
|
||||
raise ValueError("--do_predict requires a test dataset")
|
||||
test_dataset = datasets["test"]
|
||||
if data_args.max_test_samples is not None:
|
||||
test_dataset = test_dataset.select(range(data_args.max_test_samples))
|
||||
test_dataset = test_dataset.map(
|
||||
predict_dataset = datasets["test"]
|
||||
if data_args.max_predict_samples is not None:
|
||||
predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
|
||||
predict_dataset = predict_dataset.map(
|
||||
preprocess_function,
|
||||
batched=True,
|
||||
num_proc=data_args.preprocessing_num_workers,
|
||||
@@ -547,37 +547,39 @@ def main():
|
||||
metrics = trainer.evaluate(
|
||||
max_length=data_args.val_max_target_length, num_beams=data_args.num_beams, metric_key_prefix="eval"
|
||||
)
|
||||
max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset)
|
||||
metrics["eval_samples"] = min(max_val_samples, len(eval_dataset))
|
||||
max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
|
||||
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
|
||||
|
||||
trainer.log_metrics("eval", metrics)
|
||||
trainer.save_metrics("eval", metrics)
|
||||
|
||||
if training_args.do_predict:
|
||||
logger.info("*** Test ***")
|
||||
logger.info("*** Predict ***")
|
||||
|
||||
test_results = trainer.predict(
|
||||
test_dataset,
|
||||
metric_key_prefix="test",
|
||||
predict_results = trainer.predict(
|
||||
predict_dataset,
|
||||
metric_key_prefix="predict",
|
||||
max_length=data_args.val_max_target_length,
|
||||
num_beams=data_args.num_beams,
|
||||
)
|
||||
metrics = test_results.metrics
|
||||
max_test_samples = data_args.max_test_samples if data_args.max_test_samples is not None else len(test_dataset)
|
||||
metrics["test_samples"] = min(max_test_samples, len(test_dataset))
|
||||
metrics = predict_results.metrics
|
||||
max_predict_samples = (
|
||||
data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset)
|
||||
)
|
||||
metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset))
|
||||
|
||||
trainer.log_metrics("test", metrics)
|
||||
trainer.save_metrics("test", metrics)
|
||||
trainer.log_metrics("predict", metrics)
|
||||
trainer.save_metrics("predict", metrics)
|
||||
|
||||
if trainer.is_world_process_zero():
|
||||
if training_args.predict_with_generate:
|
||||
test_preds = tokenizer.batch_decode(
|
||||
test_results.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True
|
||||
predictions = tokenizer.batch_decode(
|
||||
predict_results.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True
|
||||
)
|
||||
test_preds = [pred.strip() for pred in test_preds]
|
||||
output_test_preds_file = os.path.join(training_args.output_dir, "test_generations.txt")
|
||||
with open(output_test_preds_file, "w") as writer:
|
||||
writer.write("\n".join(test_preds))
|
||||
predictions = [pred.strip() for pred in predictions]
|
||||
output_prediction_file = os.path.join(training_args.output_dir, "generated_predictions.txt")
|
||||
with open(output_prediction_file, "w") as writer:
|
||||
writer.write("\n".join(predictions))
|
||||
|
||||
if training_args.push_to_hub:
|
||||
trainer.push_to_hub()
|
||||
|
||||
Reference in New Issue
Block a user