diff --git a/examples/pytorch/summarization/run_summarization.py b/examples/pytorch/summarization/run_summarization.py index d981bdbd15..8168482ca7 100755 --- a/examples/pytorch/summarization/run_summarization.py +++ b/examples/pytorch/summarization/run_summarization.py @@ -692,7 +692,13 @@ def main(): results = {} if training_args.do_eval: logger.info("*** Evaluate ***") - metrics = trainer.evaluate(metric_key_prefix="eval") + if isinstance(eval_dataset, dict): + metrics = {} + for eval_ds_name, eval_ds in eval_dataset.items(): + dataset_metrics = trainer.evaluate(eval_dataset=eval_ds, metric_key_prefix=f"eval_{eval_ds_name}") + metrics.update(dataset_metrics) + else: + metrics = trainer.evaluate(metric_key_prefix="eval") max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))