Add multi-node conditions in trainer_qa.py and trainer_seq2seq.py (#19502)

* Add multi-node conditions in trainer_qa.py and trainer_seq2seq.py

* Code improvement
This commit is contained in:
regisss
2022-10-12 04:48:56 +02:00
committed by GitHub
parent 69b81c0a5f
commit bb2cfd1824
2 changed files with 13 additions and 5 deletions

View File

@@ -52,7 +52,8 @@ class QuestionAnsweringTrainer(Trainer):
finally:
self.compute_metrics = compute_metrics
if self.post_process_function is not None and self.compute_metrics is not None:
if self.post_process_function is not None and self.compute_metrics is not None and self.args.should_save:
# Only the main node write the results by default
eval_preds = self.post_process_function(eval_examples, eval_dataset, output.predictions)
metrics = self.compute_metrics(eval_preds)
@@ -60,11 +61,13 @@ class QuestionAnsweringTrainer(Trainer):
for key in list(metrics.keys()):
if not key.startswith(f"{metric_key_prefix}_"):
metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
self.log(metrics)
else:
metrics = {}
if self.args.should_log:
# Only the main node log the results by default
self.log(metrics)
if self.args.tpu_metrics_debug or self.args.debug:
# tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
xm.master_print(met.metrics_report())

View File

@@ -84,7 +84,8 @@ class QuestionAnsweringSeq2SeqTrainer(Seq2SeqTrainer):
)
)
if self.post_process_function is not None and self.compute_metrics is not None:
if self.post_process_function is not None and self.compute_metrics is not None and self.args.should_save:
# Only the main node write the results by default
eval_preds = self.post_process_function(eval_examples, eval_dataset, output)
metrics = self.compute_metrics(eval_preds)
@@ -94,8 +95,12 @@ class QuestionAnsweringSeq2SeqTrainer(Seq2SeqTrainer):
metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
output.metrics.update(metrics)
else:
metrics = {}
self.log(metrics)
if self.args.should_log:
# Only the main node log the results by default
self.log(metrics)
if self.args.tpu_metrics_debug or self.args.debug:
# tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)