Fix epoch number when resuming training (#21478)

This commit is contained in:
Sylvain Gugger
2023-02-06 19:34:34 -05:00
committed by GitHub
parent 35f93f299f
commit cc8407522a
2 changed files with 4 additions and 2 deletions

View File

@@ -1148,7 +1148,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
# won't be the same since the training dataloader is shuffled).
with tempfile.TemporaryDirectory() as tmpdir:
kwargs = dict(output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1)
kwargs = dict(output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1, logging_steps=5)
trainer = get_regression_trainer(**kwargs)
trainer.train()
(a, b) = trainer.model.a.item(), trainer.model.b.item()