Fix epoch number when resuming training (#21478)
This commit is contained in:
@@ -1798,8 +1798,10 @@ class Trainer:
|
||||
self._load_rng_state(resume_from_checkpoint)
|
||||
|
||||
rng_to_sync = False
|
||||
steps_skipped = 0
|
||||
if skip_first_batches is not None and steps_trained_in_current_epoch > 0:
|
||||
epoch_iterator = skip_first_batches(epoch_iterator, steps_trained_in_current_epoch)
|
||||
steps_skipped = steps_trained_in_current_epoch
|
||||
steps_trained_in_current_epoch = 0
|
||||
rng_to_sync = True
|
||||
|
||||
@@ -1907,7 +1909,7 @@ class Trainer:
|
||||
|
||||
model.zero_grad()
|
||||
self.state.global_step += 1
|
||||
self.state.epoch = epoch + (step + 1) / steps_in_epoch
|
||||
self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch
|
||||
self.control = self.callback_handler.on_step_end(args, self.state, self.control)
|
||||
|
||||
self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)
|
||||
|
||||
@@ -1148,7 +1148,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
# won't be the same since the training dataloader is shuffled).
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
kwargs = dict(output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1)
|
||||
kwargs = dict(output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1, logging_steps=5)
|
||||
trainer = get_regression_trainer(**kwargs)
|
||||
trainer.train()
|
||||
(a, b) = trainer.model.a.item(), trainer.model.b.item()
|
||||
|
||||
Reference in New Issue
Block a user