From ee04b698223bc3279e0ce97bdd9d00aaa5d9cc38 Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Fri, 26 Feb 2021 17:01:01 -0800 Subject: [PATCH] [examples] better model example (#10427) * refactors * typo --- examples/seq2seq/run_seq2seq.py | 26 ++++++------------- src/transformers/trainer.py | 2 +- src/transformers/trainer_pt_utils.py | 38 +++++++++++++++++++++++++++- 3 files changed, 46 insertions(+), 20 deletions(-) diff --git a/examples/seq2seq/run_seq2seq.py b/examples/seq2seq/run_seq2seq.py index ea6020c332..2a060dac52 100755 --- a/examples/seq2seq/run_seq2seq.py +++ b/examples/seq2seq/run_seq2seq.py @@ -572,7 +572,6 @@ def main(): compute_metrics=compute_metrics if training_args.predict_with_generate else None, ) - all_metrics = {} # Training if training_args.do_train: if last_checkpoint is not None: @@ -589,13 +588,10 @@ def main(): data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset) ) metrics["train_samples"] = min(max_train_samples, len(train_dataset)) - if trainer.is_world_process_zero(): - trainer.log_metrics("train", metrics) - trainer.save_metrics("train", metrics) - all_metrics.update(metrics) - # Need to save the state, since Trainer.save_model saves only the tokenizer with the model - trainer.state.save_to_json(os.path.join(training_args.output_dir, "trainer_state.json")) + trainer.log_metrics("train", metrics) + trainer.save_metrics("train", metrics) + trainer.save_state() # Evaluation results = {} @@ -608,10 +604,8 @@ def main(): max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset) metrics["eval_samples"] = min(max_val_samples, len(eval_dataset)) - if trainer.is_world_process_zero(): - trainer.log_metrics("eval", metrics) - trainer.save_metrics("eval", metrics) - all_metrics.update(metrics) + trainer.log_metrics("eval", metrics) + trainer.save_metrics("eval", metrics) if training_args.do_predict: logger.info("*** Test ***") @@ -626,11 +620,10 @@ def main(): max_test_samples = data_args.max_test_samples if data_args.max_test_samples is not None else len(test_dataset) metrics["test_samples"] = min(max_test_samples, len(test_dataset)) - if trainer.is_world_process_zero(): - trainer.log_metrics("test", metrics) - trainer.save_metrics("test", metrics) - all_metrics.update(metrics) + trainer.log_metrics("test", metrics) + trainer.save_metrics("test", metrics) + if trainer.is_world_process_zero(): if training_args.predict_with_generate: test_preds = tokenizer.batch_decode( test_results.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True @@ -640,9 +633,6 @@ def main(): with open(output_test_preds_file, "w") as writer: writer.write("\n".join(test_preds)) - if trainer.is_world_process_zero(): - trainer.save_metrics("all", metrics) - return results diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 29850d27d6..009c9bff10 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -231,7 +231,7 @@ class Trainer: """ - from .trainer_pt_utils import _get_learning_rate, log_metrics, metrics_format, save_metrics + from .trainer_pt_utils import _get_learning_rate, log_metrics, metrics_format, save_metrics, save_state def __init__( self, diff --git a/src/transformers/trainer_pt_utils.py b/src/transformers/trainer_pt_utils.py index eac696ec35..9b76a23241 100644 --- a/src/transformers/trainer_pt_utils.py +++ b/src/transformers/trainer_pt_utils.py @@ -599,12 +599,16 @@ def log_metrics(self, split, metrics): """ Log metrics in a specially formatted way + Under distributed environment this is done only for a process with rank 0. + Args: split (:obj:`str`): Mode/split name: one of ``train``, ``eval``, ``test`` metrics (:obj:`Dict[str, float]`): The metrics returned from train/evaluate/predictmetrics: metrics dict """ + if not self.is_world_process_zero(): + return logger.info(f"***** {split} metrics *****") metrics_formatted = self.metrics_format(metrics) @@ -614,16 +618,48 @@ def log_metrics(self, split, metrics): logger.info(f" {key: <{k_width}} = {metrics_formatted[key]:>{v_width}}") -def save_metrics(self, split, metrics): +def save_metrics(self, split, metrics, combined=True): """ Save metrics into a json file for that split, e.g. ``train_results.json``. + Under distributed environment this is done only for a process with rank 0. + Args: split (:obj:`str`): Mode/split name: one of ``train``, ``eval``, ``test``, ``all`` metrics (:obj:`Dict[str, float]`): The metrics returned from train/evaluate/predict + combined (:obj:`bool`, `optional`, defaults to :obj:`True`): + Creates combined metrics by updating ``all_results.json`` with metrics of this call """ + if not self.is_world_process_zero(): + return + path = os.path.join(self.args.output_dir, f"{split}_results.json") with open(path, "w") as f: json.dump(metrics, f, indent=4, sort_keys=True) + + if combined: + path = os.path.join(self.args.output_dir, "all_results.json") + if os.path.exists(path): + with open(path, "r") as f: + all_metrics = json.load(f) + else: + all_metrics = {} + + all_metrics.update(metrics) + with open(path, "w") as f: + json.dump(all_metrics, f, indent=4, sort_keys=True) + + +def save_state(self): + """ + Saves the Trainer state, since Trainer.save_model saves only the tokenizer with the model + + Under distributed environment this is done only for a process with rank 0. + """ + if not self.is_world_process_zero(): + return + + path = os.path.join(self.args.output_dir, "trainer_state.json") + self.state.save_to_json(path)