fix ZeroDivisionError and epoch counting (#7125)

* fix ZeroDivisionError and epoch counting

* Add test for num_train_epochs calculation in trainer.py

* Remove @require_non_multigpu for test_num_train_epochs_in_training
This commit is contained in:
Yih-Dar
2020-09-15 17:51:50 +02:00
committed by GitHub
parent 7af2791d77
commit 4c62c6021a
2 changed files with 22 additions and 7 deletions

View File

@@ -606,13 +606,15 @@ class Trainer:
# Data loader and number of training steps
train_dataloader = self.get_train_dataloader()
num_update_steps_per_epoch = len(train_dataloader) // self.args.gradient_accumulation_steps
num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
if self.args.max_steps > 0:
t_total = self.args.max_steps
num_train_epochs = (
self.args.max_steps // (len(train_dataloader) // self.args.gradient_accumulation_steps) + 1
num_train_epochs = self.args.max_steps // num_update_steps_per_epoch + int(
self.args.max_steps % num_update_steps_per_epoch > 0
)
else:
t_total = int(len(train_dataloader) // self.args.gradient_accumulation_steps * self.args.num_train_epochs)
t_total = int(num_update_steps_per_epoch * self.args.num_train_epochs)
num_train_epochs = self.args.num_train_epochs
self.args.max_steps = t_total
@@ -682,10 +684,8 @@ class Trainer:
self.global_step = int(model_path.split("-")[-1].split(os.path.sep)[0])
self.total_flos = getattr(model.config, "total_flos", 0)
epochs_trained = self.global_step // (len(train_dataloader) // self.args.gradient_accumulation_steps)
steps_trained_in_current_epoch = self.global_step % (
len(train_dataloader) // self.args.gradient_accumulation_steps
)
epochs_trained = self.global_step // num_update_steps_per_epoch
steps_trained_in_current_epoch = self.global_step % (num_update_steps_per_epoch)
logger.info(" Continuing training from checkpoint, will skip to saved global_step")
logger.info(" Continuing training from epoch %d", epochs_trained)