fix ZeroDivisionError and epoch counting (#7125)

* fix ZeroDivisionError and epoch counting

* Add test for num_train_epochs calculation in trainer.py

* Remove @require_non_multigpu for test_num_train_epochs_in_training
This commit is contained in:
Yih-Dar
2020-09-15 17:51:50 +02:00
committed by GitHub
parent 7af2791d77
commit 4c62c6021a
2 changed files with 22 additions and 7 deletions

View File

@@ -302,3 +302,18 @@ class TrainerIntegrationTest(unittest.TestCase):
trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset)
loader = trainer.get_train_dataloader()
self.assertIsInstance(loader, torch.utils.data.DataLoader)
def test_num_train_epochs_in_training(self):
# len(train_dl) < gradient_accumulation_steps shouldn't give ``ZeroDivisionError`` when ``max_steps`` is given.
# It should give 1 update step for each epoch.
trainer = get_regression_trainer(
max_steps=3, train_len=64, per_device_train_batch_size=16, gradient_accumulation_steps=5
)
train_output = trainer.train()
self.assertEqual(train_output.global_step, 3)
# Even ``max_steps`` is not specified, we still expect 1 update step for each epoch if
# len(train_dl) < gradient_accumulation_steps.
trainer = get_regression_trainer(train_len=64, per_device_train_batch_size=16, gradient_accumulation_steps=5)
train_output = trainer.train()
self.assertEqual(train_output.global_step, int(self.n_epochs))