Properly calculate the total train iterations and recalculate num epochs in no_trainer scripts (#17856)

This commit is contained in:
Zachary Mueller
2022-06-23 15:46:01 -04:00
committed by GitHub
parent 7c1b91281f
commit 75259b44bf
12 changed files with 70 additions and 37 deletions

View File

@@ -546,8 +546,6 @@ def main():
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
else:
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
lr_scheduler = get_scheduler(
name=args.lr_scheduler_type,
@@ -556,6 +554,9 @@ def main():
num_training_steps=args.max_train_steps,
)
# Afterwards we recalculate our number of training epochs
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
# 5. Train
total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps