From dd49404a897f84622d38254fe90cd07d8c1640b0 Mon Sep 17 00:00:00 2001 From: Hwijeen Ahn Date: Tue, 18 Jul 2023 10:33:41 -0700 Subject: [PATCH] check if eval dataset is dict (#24877) * check if eval dataset is dict * formatting --- examples/pytorch/summarization/run_summarization.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/examples/pytorch/summarization/run_summarization.py b/examples/pytorch/summarization/run_summarization.py index d981bdbd15..8168482ca7 100755 --- a/examples/pytorch/summarization/run_summarization.py +++ b/examples/pytorch/summarization/run_summarization.py @@ -692,7 +692,13 @@ def main(): results = {} if training_args.do_eval: logger.info("*** Evaluate ***") - metrics = trainer.evaluate(metric_key_prefix="eval") + if isinstance(eval_dataset, dict): + metrics = {} + for eval_ds_name, eval_ds in eval_dataset.items(): + dataset_metrics = trainer.evaluate(eval_dataset=eval_ds, metric_key_prefix=f"eval_{eval_ds_name}") + metrics.update(dataset_metrics) + else: + metrics = trainer.evaluate(metric_key_prefix="eval") max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))