check if eval dataset is dict (#24877)
* check if eval dataset is dict * formatting
This commit is contained in:
@@ -692,7 +692,13 @@ def main():
|
||||
results = {}
|
||||
if training_args.do_eval:
|
||||
logger.info("*** Evaluate ***")
|
||||
metrics = trainer.evaluate(metric_key_prefix="eval")
|
||||
if isinstance(eval_dataset, dict):
|
||||
metrics = {}
|
||||
for eval_ds_name, eval_ds in eval_dataset.items():
|
||||
dataset_metrics = trainer.evaluate(eval_dataset=eval_ds, metric_key_prefix=f"eval_{eval_ds_name}")
|
||||
metrics.update(dataset_metrics)
|
||||
else:
|
||||
metrics = trainer.evaluate(metric_key_prefix="eval")
|
||||
max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
|
||||
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user