From bb2cfd18245a3abaef05d564e801a9f9f759feca Mon Sep 17 00:00:00 2001 From: regisss <15324346+regisss@users.noreply.github.com> Date: Wed, 12 Oct 2022 04:48:56 +0200 Subject: [PATCH] Add multi-node conditions in trainer_qa.py and trainer_seq2seq.py (#19502) * Add multi-node conditions in trainer_qa.py and trainer_seq2seq.py * Code improvement --- examples/pytorch/question-answering/trainer_qa.py | 9 ++++++--- .../pytorch/question-answering/trainer_seq2seq_qa.py | 9 +++++++-- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/examples/pytorch/question-answering/trainer_qa.py b/examples/pytorch/question-answering/trainer_qa.py index 59d7a084c1..cdf8889a45 100644 --- a/examples/pytorch/question-answering/trainer_qa.py +++ b/examples/pytorch/question-answering/trainer_qa.py @@ -52,7 +52,8 @@ class QuestionAnsweringTrainer(Trainer): finally: self.compute_metrics = compute_metrics - if self.post_process_function is not None and self.compute_metrics is not None: + if self.post_process_function is not None and self.compute_metrics is not None and self.args.should_save: + # Only the main node writes the results by default eval_preds = self.post_process_function(eval_examples, eval_dataset, output.predictions) metrics = self.compute_metrics(eval_preds) @@ -60,11 +61,13 @@ class QuestionAnsweringTrainer(Trainer): for key in list(metrics.keys()): if not key.startswith(f"{metric_key_prefix}_"): metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key) - - self.log(metrics) else: metrics = {} + if self.args.should_log: + # Only the main node logs the results by default + self.log(metrics) + if self.args.tpu_metrics_debug or self.args.debug: # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.) 
xm.master_print(met.metrics_report()) diff --git a/examples/pytorch/question-answering/trainer_seq2seq_qa.py b/examples/pytorch/question-answering/trainer_seq2seq_qa.py index ab46435062..90acc05208 100644 --- a/examples/pytorch/question-answering/trainer_seq2seq_qa.py +++ b/examples/pytorch/question-answering/trainer_seq2seq_qa.py @@ -84,7 +84,8 @@ class QuestionAnsweringSeq2SeqTrainer(Seq2SeqTrainer): ) ) - if self.post_process_function is not None and self.compute_metrics is not None: + if self.post_process_function is not None and self.compute_metrics is not None and self.args.should_save: + # Only the main node writes the results by default eval_preds = self.post_process_function(eval_examples, eval_dataset, output) metrics = self.compute_metrics(eval_preds) @@ -94,8 +95,12 @@ class QuestionAnsweringSeq2SeqTrainer(Seq2SeqTrainer): metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key) output.metrics.update(metrics) + else: + metrics = {} - self.log(metrics) + if self.args.should_log: + # Only the main node logs the results by default + self.log(metrics) if self.args.tpu_metrics_debug or self.args.debug: # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)