From bdcc4b78a27775d1ec8f3fd297cb679c257289db Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Thu, 1 Oct 2020 14:13:29 -0400 Subject: [PATCH] Fix seq2seq example test (#7518) * Fix seq2seq example test * Fix bad copy-paste * Also save the state --- examples/seq2seq/finetune_trainer.py | 1 + examples/seq2seq/test_finetune_trainer.py | 7 +++---- src/transformers/trainer.py | 1 + 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/examples/seq2seq/finetune_trainer.py b/examples/seq2seq/finetune_trainer.py index c1192c971f..049f0f9ca5 100644 --- a/examples/seq2seq/finetune_trainer.py +++ b/examples/seq2seq/finetune_trainer.py @@ -276,6 +276,7 @@ def main(): # For convenience, we also re-save the tokenizer to the same directory, # so that you can share your model easily on huggingface.co/models =) if trainer.is_world_process_zero(): + trainer.state.save_to_json(os.path.join(training_args.output_dir, "trainer_state.json")) tokenizer.save_pretrained(training_args.output_dir) # Evaluation diff --git a/examples/seq2seq/test_finetune_trainer.py b/examples/seq2seq/test_finetune_trainer.py index 80f8d699b0..156f51b1ed 100644 --- a/examples/seq2seq/test_finetune_trainer.py +++ b/examples/seq2seq/test_finetune_trainer.py @@ -4,11 +4,10 @@ import tempfile from unittest.mock import patch from transformers.testing_utils import slow -from transformers.trainer_utils import set_seed +from transformers.trainer_utils import TrainerState, set_seed from .finetune_trainer import main from .test_seq2seq_examples import MBART_TINY -from .utils import load_json set_seed(42) @@ -17,7 +16,7 @@ MARIAN_MODEL = "sshleifer/student_marian_en_ro_6_1" def test_finetune_trainer(): output_dir = run_trainer(1, "12", MBART_TINY, 1) - logs = load_json(os.path.join(output_dir, "log_history.json")) + logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history eval_metrics = [log for log in logs if "eval_loss" in log.keys()] first_step_stats = eval_metrics[0] assert "eval_bleu" in first_step_stats @@ -30,7 +29,7 @@ def test_finetune_trainer_slow(): output_dir = run_trainer(eval_steps=2, max_len="32", model_name=MARIAN_MODEL, num_train_epochs=3) # Check metrics - logs = load_json(os.path.join(output_dir, "log_history.json")) + logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history eval_metrics = [log for log in logs if "eval_loss" in log.keys()] first_step_stats = eval_metrics[0] last_step_stats = eval_metrics[-1] diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index b13f9dbc19..12409338f4 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -601,6 +601,7 @@ class Trainer: output_dir = os.path.join(self.args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}") self.save_model(output_dir) if self.is_world_master(): + self.state.save_to_json(os.path.join(output_dir, "trainer_state.json")) torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt")) torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))