From cc8407522a0443628e0b62827488fd6a17213f35 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Mon, 6 Feb 2023 19:34:34 -0500 Subject: [PATCH] Fix epoch number when resuming training (#21478) --- src/transformers/trainer.py | 4 +++- tests/trainer/test_trainer.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index f10cc2f4b4..cdeef71aa6 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -1798,8 +1798,10 @@ class Trainer: self._load_rng_state(resume_from_checkpoint) rng_to_sync = False + steps_skipped = 0 if skip_first_batches is not None and steps_trained_in_current_epoch > 0: epoch_iterator = skip_first_batches(epoch_iterator, steps_trained_in_current_epoch) + steps_skipped = steps_trained_in_current_epoch steps_trained_in_current_epoch = 0 rng_to_sync = True @@ -1907,7 +1909,7 @@ class Trainer: model.zero_grad() self.state.global_step += 1 - self.state.epoch = epoch + (step + 1) / steps_in_epoch + self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch self.control = self.callback_handler.on_step_end(args, self.state, self.control) self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index a09865448d..2c26deeffe 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1148,7 +1148,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon): # won't be the same since the training dataloader is shuffled). with tempfile.TemporaryDirectory() as tmpdir: - kwargs = dict(output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1) + kwargs = dict(output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1, logging_steps=5) trainer = get_regression_trainer(**kwargs) trainer.train() (a, b) = trainer.model.a.item(), trainer.model.b.item()