[MNLI example] Prevent overwriting matched with mismatched metrics (#16475)
* Prevent overwriting matched with mismatched metrics * Fix style
This commit is contained in:
@@ -507,6 +507,7 @@ def main():
|
||||
if data_args.task_name == "mnli":
|
||||
tasks.append("mnli-mm")
|
||||
eval_datasets.append(raw_datasets["validation_mismatched"])
|
||||
combined = {}
|
||||
|
||||
for eval_dataset, task in zip(eval_datasets, tasks):
|
||||
metrics = trainer.evaluate(eval_dataset=eval_dataset)
|
||||
@@ -516,8 +517,13 @@ def main():
|
||||
)
|
||||
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
|
||||
|
||||
if task == "mnli-mm":
|
||||
metrics = {k + "_mm": v for k, v in metrics.items()}
|
||||
if "mnli" in task:
|
||||
combined.update(metrics)
|
||||
|
||||
trainer.log_metrics("eval", metrics)
|
||||
trainer.save_metrics("eval", metrics)
|
||||
trainer.save_metrics("eval", combined if "mnli" in task else metrics)
|
||||
|
||||
if training_args.do_predict:
|
||||
logger.info("*** Predict ***")
|
||||
|
||||
Reference in New Issue
Block a user