From 4c62c6021aa3341eff4c0bb87fea72fea22eb701 Mon Sep 17 00:00:00 2001 From: Yih-Dar Date: Tue, 15 Sep 2020 17:51:50 +0200 Subject: [PATCH] fix ZeroDivisionError and epoch counting (#7125) * fix ZeroDivisionError and epoch counting * Add test for num_train_epochs calculation in trainer.py * Remove @require_non_multigpu for test_num_train_epochs_in_training --- src/transformers/trainer.py | 14 +++++++------- tests/test_trainer.py | 15 +++++++++++++++ 2 files changed, 22 insertions(+), 7 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index e2b4a854f1..2e314c896e 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -606,13 +606,15 @@ class Trainer: # Data loader and number of training steps train_dataloader = self.get_train_dataloader() + num_update_steps_per_epoch = len(train_dataloader) // self.args.gradient_accumulation_steps + num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1) if self.args.max_steps > 0: t_total = self.args.max_steps - num_train_epochs = ( - self.args.max_steps // (len(train_dataloader) // self.args.gradient_accumulation_steps) + 1 + num_train_epochs = self.args.max_steps // num_update_steps_per_epoch + int( + self.args.max_steps % num_update_steps_per_epoch > 0 ) else: - t_total = int(len(train_dataloader) // self.args.gradient_accumulation_steps * self.args.num_train_epochs) + t_total = int(num_update_steps_per_epoch * self.args.num_train_epochs) num_train_epochs = self.args.num_train_epochs self.args.max_steps = t_total @@ -682,10 +684,8 @@ class Trainer: self.global_step = int(model_path.split("-")[-1].split(os.path.sep)[0]) self.total_flos = getattr(model.config, "total_flos", 0) - epochs_trained = self.global_step // (len(train_dataloader) // self.args.gradient_accumulation_steps) - steps_trained_in_current_epoch = self.global_step % ( - len(train_dataloader) // self.args.gradient_accumulation_steps - ) + epochs_trained = self.global_step // num_update_steps_per_epoch + steps_trained_in_current_epoch = self.global_step % (num_update_steps_per_epoch) logger.info(" Continuing training from checkpoint, will skip to saved global_step") logger.info(" Continuing training from epoch %d", epochs_trained) diff --git a/tests/test_trainer.py b/tests/test_trainer.py index 16bb56e88a..f5bbe9145b 100755 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -302,3 +302,18 @@ class TrainerIntegrationTest(unittest.TestCase): trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset) loader = trainer.get_train_dataloader() self.assertIsInstance(loader, torch.utils.data.DataLoader) + + def test_num_train_epochs_in_training(self): + # len(train_dl) < gradient_accumulation_steps shouldn't give ``ZeroDivisionError`` when ``max_steps`` is given. + # It should give 1 update step for each epoch. + trainer = get_regression_trainer( + max_steps=3, train_len=64, per_device_train_batch_size=16, gradient_accumulation_steps=5 + ) + train_output = trainer.train() + self.assertEqual(train_output.global_step, 3) + + # Even ``max_steps`` is not specified, we still expect 1 update step for each epoch if + # len(train_dl) < gradient_accumulation_steps. + trainer = get_regression_trainer(train_len=64, per_device_train_batch_size=16, gradient_accumulation_steps=5) + train_output = trainer.train() + self.assertEqual(train_output.global_step, int(self.n_epochs))