Update all no_trainer with skip_first_batches (#23664)

This commit is contained in:
Zachary Mueller
2023-05-22 14:49:31 -04:00
committed by GitHub
parent 26a06814a1
commit b191d7db44
12 changed files with 115 additions and 98 deletions

View File

@@ -668,12 +668,12 @@ def main():
model.train()
if args.with_tracking:
total_loss = 0
for step, batch in enumerate(train_dataloader):
# We need to skip steps until we reach the resumed step
if args.resume_from_checkpoint and epoch == starting_epoch:
if resume_step is not None and step < resume_step:
completed_steps += 1
continue
if args.resume_from_checkpoint and epoch == starting_epoch and resume_step is not None:
# We skip the first `n` batches in the dataloader when resuming from a checkpoint
active_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
else:
active_dataloader = train_dataloader
for step, batch in enumerate(active_dataloader):
outputs = model(**batch)
loss = outputs.loss
# We keep track of the loss at each epoch