Fixup no_trainer save logic (#16968)

* Fixup all examples
This commit is contained in:
Zachary Mueller
2022-04-27 14:46:49 -04:00
committed by GitHub
parent c79bbc3ba5
commit 60e1d883f1
12 changed files with 200 additions and 132 deletions

View File

@@ -393,32 +393,38 @@ def main():
# Only show the progress bar once on each machine. # Only show the progress bar once on each machine.
progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
completed_steps = 0 completed_steps = 0
starting_epoch = 0
# Potentially load in the weights and states from a previous save # Potentially load in the weights and states from a previous save
if args.resume_from_checkpoint: if args.resume_from_checkpoint:
if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "": if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "":
accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}") accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}")
accelerator.load_state(args.resume_from_checkpoint) accelerator.load_state(args.resume_from_checkpoint)
resume_step = None path = os.path.basename(args.resume_from_checkpoint)
path = args.resume_from_checkpoint
else: else:
# Get the most recent checkpoint # Get the most recent checkpoint
dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()] dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()]
dirs.sort(key=os.path.getctime) dirs.sort(key=os.path.getctime)
path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last
if "epoch" in path: # Extract `epoch_{i}` or `step_{i}`
args.num_train_epochs -= int(path.replace("epoch_", "")) training_difference = os.path.splitext(path)[0]
else:
resume_step = int(path.replace("step_", ""))
args.num_train_epochs -= resume_step // len(train_dataloader)
resume_step = (args.num_train_epochs * len(train_dataloader)) - resume_step
for epoch in range(args.num_train_epochs): if "epoch" in training_difference:
starting_epoch = int(training_difference.replace("epoch_", "")) + 1
resume_step = None
else:
resume_step = int(training_difference.replace("step_", ""))
starting_epoch = resume_step // len(train_dataloader)
resume_step -= starting_epoch * len(train_dataloader)
for epoch in range(starting_epoch, args.num_train_epochs):
model.train() model.train()
if args.with_tracking: if args.with_tracking:
total_loss = 0 total_loss = 0
for step, batch in enumerate(train_dataloader): for step, batch in enumerate(train_dataloader):
# We need to skip steps until we reach the resumed step # We need to skip steps until we reach the resumed step
if args.resume_from_checkpoint and epoch == 0 and step < resume_step: if args.resume_from_checkpoint and epoch == starting_epoch:
if resume_step is not None and step < resume_step:
completed_steps += 1
continue continue
outputs = model(**batch) outputs = model(**batch)
loss = outputs.loss loss = outputs.loss

View File

@@ -503,33 +503,39 @@ def main():
# Only show the progress bar once on each machine. # Only show the progress bar once on each machine.
progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
completed_steps = 0 completed_steps = 0
starting_epoch = 0
# Potentially load in the weights and states from a previous save # Potentially load in the weights and states from a previous save
if args.resume_from_checkpoint: if args.resume_from_checkpoint:
if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "": if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "":
accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}") accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}")
accelerator.load_state(args.resume_from_checkpoint) accelerator.load_state(args.resume_from_checkpoint)
resume_step = None path = os.path.basename(args.resume_from_checkpoint)
path = args.resume_from_checkpoint
else: else:
# Get the most recent checkpoint # Get the most recent checkpoint
dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()] dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()]
dirs.sort(key=os.path.getctime) dirs.sort(key=os.path.getctime)
path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last
if "epoch" in path: # Extract `epoch_{i}` or `step_{i}`
args.num_train_epochs -= int(path.replace("epoch_", "")) training_difference = os.path.splitext(path)[0]
else:
resume_step = int(path.replace("step_", ""))
args.num_train_epochs -= resume_step // len(train_dataloader)
resume_step = (args.num_train_epochs * len(train_dataloader)) - resume_step
for epoch in range(args.num_train_epochs): if "epoch" in training_difference:
starting_epoch = int(training_difference.replace("epoch_", "")) + 1
resume_step = None
else:
resume_step = int(training_difference.replace("step_", ""))
starting_epoch = resume_step // len(train_dataloader)
resume_step -= starting_epoch * len(train_dataloader)
for epoch in range(starting_epoch, args.num_train_epochs):
model.train() model.train()
if args.with_tracking: if args.with_tracking:
total_loss = 0 total_loss = 0
for step, batch in enumerate(train_dataloader): for step, batch in enumerate(train_dataloader):
# We need to skip steps until we reach the resumed step # We need to skip steps until we reach the resumed step
if args.resume_from_checkpoint and epoch == 0 and step < resume_step: if args.resume_from_checkpoint and epoch == starting_epoch:
if resume_step is not None and step < resume_step:
completed_steps += 1
continue continue
outputs = model(**batch) outputs = model(**batch)
loss = outputs.loss loss = outputs.loss

View File

@@ -549,33 +549,39 @@ def main():
# Only show the progress bar once on each machine. # Only show the progress bar once on each machine.
progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
completed_steps = 0 completed_steps = 0
starting_epoch = 0
# Potentially load in the weights and states from a previous save # Potentially load in the weights and states from a previous save
if args.resume_from_checkpoint: if args.resume_from_checkpoint:
if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "": if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "":
accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}") accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}")
accelerator.load_state(args.resume_from_checkpoint) accelerator.load_state(args.resume_from_checkpoint)
resume_step = None path = os.path.basename(args.resume_from_checkpoint)
path = args.resume_from_checkpoint
else: else:
# Get the most recent checkpoint # Get the most recent checkpoint
dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()] dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()]
dirs.sort(key=os.path.getctime) dirs.sort(key=os.path.getctime)
path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last
if "epoch" in path: # Extract `epoch_{i}` or `step_{i}`
args.num_train_epochs -= int(path.replace("epoch_", "")) training_difference = os.path.splitext(path)[0]
else:
resume_step = int(path.replace("step_", ""))
args.num_train_epochs -= resume_step // len(train_dataloader)
resume_step = (args.num_train_epochs * len(train_dataloader)) - resume_step
for epoch in range(args.num_train_epochs): if "epoch" in training_difference:
starting_epoch = int(training_difference.replace("epoch_", "")) + 1
resume_step = None
else:
resume_step = int(training_difference.replace("step_", ""))
starting_epoch = resume_step // len(train_dataloader)
resume_step -= starting_epoch * len(train_dataloader)
for epoch in range(starting_epoch, args.num_train_epochs):
model.train() model.train()
if args.with_tracking: if args.with_tracking:
total_loss = 0 total_loss = 0
for step, batch in enumerate(train_dataloader): for step, batch in enumerate(train_dataloader):
# We need to skip steps until we reach the resumed step # We need to skip steps until we reach the resumed step
if args.resume_from_checkpoint and epoch == 0 and step < resume_step: if args.resume_from_checkpoint and epoch == starting_epoch:
if resume_step is not None and step < resume_step:
completed_steps += 1
continue continue
outputs = model(**batch) outputs = model(**batch)
loss = outputs.loss loss = outputs.loss

View File

@@ -506,33 +506,39 @@ def main():
# Only show the progress bar once on each machine. # Only show the progress bar once on each machine.
progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
completed_steps = 0 completed_steps = 0
starting_epoch = 0
# Potentially load in the weights and states from a previous save # Potentially load in the weights and states from a previous save
if args.resume_from_checkpoint: if args.resume_from_checkpoint:
if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "": if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "":
accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}") accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}")
accelerator.load_state(args.resume_from_checkpoint) accelerator.load_state(args.resume_from_checkpoint)
resume_step = None path = os.path.basename(args.resume_from_checkpoint)
path = args.resume_from_checkpoint
else: else:
# Get the most recent checkpoint # Get the most recent checkpoint
dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()] dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()]
dirs.sort(key=os.path.getctime) dirs.sort(key=os.path.getctime)
path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last
if "epoch" in path: # Extract `epoch_{i}` or `step_{i}`
args.num_train_epochs -= int(path.replace("epoch_", "")) training_difference = os.path.splitext(path)[0]
else:
resume_step = int(path.replace("step_", ""))
args.num_train_epochs -= resume_step // len(train_dataloader)
resume_step = (args.num_train_epochs * len(train_dataloader)) - resume_step
for epoch in range(args.num_train_epochs): if "epoch" in training_difference:
starting_epoch = int(training_difference.replace("epoch_", "")) + 1
resume_step = None
else:
resume_step = int(training_difference.replace("step_", ""))
starting_epoch = resume_step // len(train_dataloader)
resume_step -= starting_epoch * len(train_dataloader)
for epoch in range(starting_epoch, args.num_train_epochs):
model.train() model.train()
if args.with_tracking: if args.with_tracking:
total_loss = 0 total_loss = 0
for step, batch in enumerate(train_dataloader): for step, batch in enumerate(train_dataloader):
# We need to skip steps until we reach the resumed step # We need to skip steps until we reach the resumed step
if args.resume_from_checkpoint and epoch == 0 and step < resume_step: if args.resume_from_checkpoint and epoch == starting_epoch:
if resume_step is not None and step < resume_step:
completed_steps += 1
continue continue
outputs = model(**batch) outputs = model(**batch)
loss = outputs.loss loss = outputs.loss

View File

@@ -765,33 +765,39 @@ def main():
# Only show the progress bar once on each machine. # Only show the progress bar once on each machine.
progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
completed_steps = 0 completed_steps = 0
starting_epoch = 0
# Potentially load in the weights and states from a previous save # Potentially load in the weights and states from a previous save
if args.resume_from_checkpoint: if args.resume_from_checkpoint:
if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "": if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "":
accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}") accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}")
accelerator.load_state(args.resume_from_checkpoint) accelerator.load_state(args.resume_from_checkpoint)
resume_step = None path = os.path.basename(args.resume_from_checkpoint)
path = args.resume_from_checkpoint
else: else:
# Get the most recent checkpoint # Get the most recent checkpoint
dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()] dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()]
dirs.sort(key=os.path.getctime) dirs.sort(key=os.path.getctime)
path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last
if "epoch" in path: # Extract `epoch_{i}` or `step_{i}`
args.num_train_epochs -= int(path.replace("epoch_", "")) training_difference = os.path.splitext(path)[0]
else:
resume_step = int(path.replace("step_", ""))
args.num_train_epochs -= resume_step // len(train_dataloader)
resume_step = (args.num_train_epochs * len(train_dataloader)) - resume_step
for epoch in range(args.num_train_epochs): if "epoch" in training_difference:
starting_epoch = int(training_difference.replace("epoch_", "")) + 1
resume_step = None
else:
resume_step = int(training_difference.replace("step_", ""))
starting_epoch = resume_step // len(train_dataloader)
resume_step -= starting_epoch * len(train_dataloader)
for epoch in range(starting_epoch, args.num_train_epochs):
model.train() model.train()
if args.with_tracking: if args.with_tracking:
total_loss = 0 total_loss = 0
for step, batch in enumerate(train_dataloader): for step, batch in enumerate(train_dataloader):
# We need to skip steps until we reach the resumed step # We need to skip steps until we reach the resumed step
if args.resume_from_checkpoint and epoch == 0 and step < resume_step: if args.resume_from_checkpoint and epoch == starting_epoch:
if resume_step is not None and step < resume_step:
completed_steps += 1
continue continue
outputs = model(**batch) outputs = model(**batch)
loss = outputs.loss loss = outputs.loss

View File

@@ -771,33 +771,39 @@ def main():
# Only show the progress bar once on each machine. # Only show the progress bar once on each machine.
progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
completed_steps = 0 completed_steps = 0
starting_epoch = 0
# Potentially load in the weights and states from a previous save # Potentially load in the weights and states from a previous save
if args.resume_from_checkpoint: if args.resume_from_checkpoint:
if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "": if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "":
accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}") accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}")
accelerator.load_state(args.resume_from_checkpoint) accelerator.load_state(args.resume_from_checkpoint)
resume_step = None path = os.path.basename(args.resume_from_checkpoint)
path = args.resume_from_checkpoint
else: else:
# Get the most recent checkpoint # Get the most recent checkpoint
dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()] dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()]
dirs.sort(key=os.path.getctime) dirs.sort(key=os.path.getctime)
path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last
if "epoch" in path: # Extract `epoch_{i}` or `step_{i}`
args.num_train_epochs -= int(path.replace("epoch_", "")) training_difference = os.path.splitext(path)[0]
else:
resume_step = int(path.replace("step_", ""))
args.num_train_epochs -= resume_step // len(train_dataloader)
resume_step = (args.num_train_epochs * len(train_dataloader)) - resume_step
for epoch in range(args.num_train_epochs): if "epoch" in training_difference:
starting_epoch = int(training_difference.replace("epoch_", "")) + 1
resume_step = None
else:
resume_step = int(training_difference.replace("step_", ""))
starting_epoch = resume_step // len(train_dataloader)
resume_step -= starting_epoch * len(train_dataloader)
for epoch in range(starting_epoch, args.num_train_epochs):
model.train() model.train()
if args.with_tracking: if args.with_tracking:
total_loss = 0 total_loss = 0
for step, batch in enumerate(train_dataloader): for step, batch in enumerate(train_dataloader):
# We need to skip steps until we reach the resumed step # We need to skip steps until we reach the resumed step
if args.resume_from_checkpoint and epoch == 0 and step < resume_step: if args.resume_from_checkpoint and epoch == starting_epoch:
if resume_step is not None and step < resume_step:
completed_steps += 1
continue continue
outputs = model(**batch) outputs = model(**batch)
loss = outputs.loss loss = outputs.loss

View File

@@ -501,33 +501,39 @@ def main():
# Only show the progress bar once on each machine. # Only show the progress bar once on each machine.
progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
completed_steps = 0 completed_steps = 0
starting_epoch = 0
# Potentially load in the weights and states from a previous save # Potentially load in the weights and states from a previous save
if args.resume_from_checkpoint: if args.resume_from_checkpoint:
if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "": if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "":
accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}") accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}")
accelerator.load_state(args.resume_from_checkpoint) accelerator.load_state(args.resume_from_checkpoint)
resume_step = None path = os.path.basename(args.resume_from_checkpoint)
path = args.resume_from_checkpoint
else: else:
# Get the most recent checkpoint # Get the most recent checkpoint
dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()] dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()]
dirs.sort(key=os.path.getctime) dirs.sort(key=os.path.getctime)
path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last
if "epoch" in path: # Extract `epoch_{i}` or `step_{i}`
args.num_train_epochs -= int(path.replace("epoch_", "")) training_difference = os.path.splitext(path)[0]
else:
resume_step = int(path.replace("step_", ""))
args.num_train_epochs -= resume_step // len(train_dataloader)
resume_step = (args.num_train_epochs * len(train_dataloader)) - resume_step
for epoch in range(args.num_train_epochs): if "epoch" in training_difference:
starting_epoch = int(training_difference.replace("epoch_", "")) + 1
resume_step = None
else:
resume_step = int(training_difference.replace("step_", ""))
starting_epoch = resume_step // len(train_dataloader)
resume_step -= starting_epoch * len(train_dataloader)
for epoch in range(starting_epoch, args.num_train_epochs):
if args.with_tracking: if args.with_tracking:
total_loss = 0 total_loss = 0
model.train() model.train()
for step, batch in enumerate(train_dataloader): for step, batch in enumerate(train_dataloader):
# We need to skip steps until we reach the resumed step # We need to skip steps until we reach the resumed step
if args.resume_from_checkpoint and epoch == 0 and step < resume_step: if args.resume_from_checkpoint and epoch == starting_epoch:
if resume_step is not None and step < resume_step:
completed_steps += 1
continue continue
outputs = model(**batch) outputs = model(**batch)
loss = outputs.loss loss = outputs.loss

View File

@@ -563,11 +563,13 @@ def main():
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
logger.info(f" Total optimization steps = {args.max_train_steps}") logger.info(f" Total optimization steps = {args.max_train_steps}")
completed_steps = 0 completed_steps = 0
starting_epoch = 0
# Only show the progress bar once on each machine. # Only show the progress bar once on each machine.
progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
completed_steps = 0 completed_steps = 0
for epoch in range(args.num_train_epochs): starting_epoch = 0
for epoch in range(starting_epoch, args.num_train_epochs):
model.train() model.train()
for step, batch in enumerate(train_dataloader): for step, batch in enumerate(train_dataloader):
# compute num of losses # compute num of losses

View File

@@ -569,32 +569,38 @@ def main():
# Only show the progress bar once on each machine. # Only show the progress bar once on each machine.
progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
completed_steps = 0 completed_steps = 0
starting_epoch = 0
# Potentially load in the weights and states from a previous save # Potentially load in the weights and states from a previous save
if args.resume_from_checkpoint: if args.resume_from_checkpoint:
if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "": if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "":
accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}") accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}")
accelerator.load_state(args.resume_from_checkpoint) accelerator.load_state(args.resume_from_checkpoint)
resume_step = None path = os.path.basename(args.resume_from_checkpoint)
path = args.resume_from_checkpoint
else: else:
# Get the most recent checkpoint # Get the most recent checkpoint
dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()] dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()]
dirs.sort(key=os.path.getctime) dirs.sort(key=os.path.getctime)
path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last
if "epoch" in path: # Extract `epoch_{i}` or `step_{i}`
args.num_train_epochs -= int(path.replace("epoch_", "")) training_difference = os.path.splitext(path)[0]
else:
resume_step = int(path.replace("step_", ""))
args.num_train_epochs -= resume_step // len(train_dataloader)
resume_step = (args.num_train_epochs * len(train_dataloader)) - resume_step
for epoch in range(args.num_train_epochs): if "epoch" in training_difference:
starting_epoch = int(training_difference.replace("epoch_", "")) + 1
resume_step = None
else:
resume_step = int(training_difference.replace("step_", ""))
starting_epoch = resume_step // len(train_dataloader)
resume_step -= starting_epoch * len(train_dataloader)
for epoch in range(starting_epoch, args.num_train_epochs):
model.train() model.train()
if args.with_tracking: if args.with_tracking:
total_loss = 0 total_loss = 0
for step, batch in enumerate(train_dataloader): for step, batch in enumerate(train_dataloader):
# We need to skip steps until we reach the resumed step # We need to skip steps until we reach the resumed step
if args.resume_from_checkpoint and epoch == 0 and step < resume_step: if args.resume_from_checkpoint and epoch == starting_epoch:
if resume_step is not None and step < resume_step:
completed_steps += 1
continue continue
outputs = model(**batch) outputs = model(**batch)
loss = outputs.loss loss = outputs.loss

View File

@@ -454,32 +454,38 @@ def main():
# Only show the progress bar once on each machine. # Only show the progress bar once on each machine.
progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
completed_steps = 0 completed_steps = 0
starting_epoch = 0
# Potentially load in the weights and states from a previous save # Potentially load in the weights and states from a previous save
if args.resume_from_checkpoint: if args.resume_from_checkpoint:
if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "": if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "":
accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}") accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}")
accelerator.load_state(args.resume_from_checkpoint) accelerator.load_state(args.resume_from_checkpoint)
resume_step = None path = os.path.basename(args.resume_from_checkpoint)
path = args.resume_from_checkpoint
else: else:
# Get the most recent checkpoint # Get the most recent checkpoint
dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()] dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()]
dirs.sort(key=os.path.getctime) dirs.sort(key=os.path.getctime)
path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last
if "epoch" in path: # Extract `epoch_{i}` or `step_{i}`
args.num_train_epochs -= int(path.replace("epoch_", "")) training_difference = os.path.splitext(path)[0]
else:
resume_step = int(path.replace("step_", ""))
args.num_train_epochs -= resume_step // len(train_dataloader)
resume_step = (args.num_train_epochs * len(train_dataloader)) - resume_step
for epoch in range(args.num_train_epochs): if "epoch" in training_difference:
starting_epoch = int(training_difference.replace("epoch_", "")) + 1
resume_step = None
else:
resume_step = int(training_difference.replace("step_", ""))
starting_epoch = resume_step // len(train_dataloader)
resume_step -= starting_epoch * len(train_dataloader)
for epoch in range(starting_epoch, args.num_train_epochs):
model.train() model.train()
if args.with_tracking: if args.with_tracking:
total_loss = 0 total_loss = 0
for step, batch in enumerate(train_dataloader): for step, batch in enumerate(train_dataloader):
# We need to skip steps until we reach the resumed step # We need to skip steps until we reach the resumed step
if args.resume_from_checkpoint and epoch == 0 and step < resume_step: if args.resume_from_checkpoint and epoch == starting_epoch:
if resume_step is not None and step < resume_step:
completed_steps += 1
continue continue
outputs = model(**batch) outputs = model(**batch)
loss = outputs.loss loss = outputs.loss

View File

@@ -606,32 +606,38 @@ def main():
# Only show the progress bar once on each machine. # Only show the progress bar once on each machine.
progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
completed_steps = 0 completed_steps = 0
starting_epoch = 0
# Potentially load in the weights and states from a previous save # Potentially load in the weights and states from a previous save
if args.resume_from_checkpoint: if args.resume_from_checkpoint:
if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "": if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "":
accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}") accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}")
accelerator.load_state(args.resume_from_checkpoint) accelerator.load_state(args.resume_from_checkpoint)
resume_step = None path = os.path.basename(args.resume_from_checkpoint)
path = args.resume_from_checkpoint
else: else:
# Get the most recent checkpoint # Get the most recent checkpoint
dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()] dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()]
dirs.sort(key=os.path.getctime) dirs.sort(key=os.path.getctime)
path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last
if "epoch" in path: # Extract `epoch_{i}` or `step_{i}`
args.num_train_epochs -= int(path.replace("epoch_", "")) training_difference = os.path.splitext(path)[0]
else:
resume_step = int(path.replace("step_", ""))
args.num_train_epochs -= resume_step // len(train_dataloader)
resume_step = (args.num_train_epochs * len(train_dataloader)) - resume_step
for epoch in range(args.num_train_epochs): if "epoch" in training_difference:
starting_epoch = int(training_difference.replace("epoch_", "")) + 1
resume_step = None
else:
resume_step = int(training_difference.replace("step_", ""))
starting_epoch = resume_step // len(train_dataloader)
resume_step -= starting_epoch * len(train_dataloader)
for epoch in range(starting_epoch, args.num_train_epochs):
model.train() model.train()
if args.with_tracking: if args.with_tracking:
total_loss = 0 total_loss = 0
for step, batch in enumerate(train_dataloader): for step, batch in enumerate(train_dataloader):
# We need to skip steps until we reach the resumed step # We need to skip steps until we reach the resumed step
if args.resume_from_checkpoint and epoch == 0 and step < resume_step: if args.resume_from_checkpoint and epoch == starting_epoch:
if resume_step is not None and step < resume_step:
completed_steps += 1
continue continue
outputs = model(**batch) outputs = model(**batch)
loss = outputs.loss loss = outputs.loss

View File

@@ -552,33 +552,39 @@ def main():
# Only show the progress bar once on each machine. # Only show the progress bar once on each machine.
progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
completed_steps = 0 completed_steps = 0
starting_epoch = 0
# Potentially load in the weights and states from a previous save # Potentially load in the weights and states from a previous save
if args.resume_from_checkpoint: if args.resume_from_checkpoint:
if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "": if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "":
accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}") accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}")
accelerator.load_state(args.resume_from_checkpoint) accelerator.load_state(args.resume_from_checkpoint)
resume_step = None path = os.path.basename(args.resume_from_checkpoint)
path = args.resume_from_checkpoint
else: else:
# Get the most recent checkpoint # Get the most recent checkpoint
dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()] dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()]
dirs.sort(key=os.path.getctime) dirs.sort(key=os.path.getctime)
path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last
if "epoch" in path: # Extract `epoch_{i}` or `step_{i}`
args.num_train_epochs -= int(path.replace("epoch_", "")) training_difference = os.path.splitext(path)[0]
else:
resume_step = int(path.replace("step_", ""))
args.num_train_epochs -= resume_step // len(train_dataloader)
resume_step = (args.num_train_epochs * len(train_dataloader)) - resume_step
for epoch in range(args.num_train_epochs): if "epoch" in training_difference:
starting_epoch = int(training_difference.replace("epoch_", "")) + 1
resume_step = None
else:
resume_step = int(training_difference.replace("step_", ""))
starting_epoch = resume_step // len(train_dataloader)
resume_step -= starting_epoch * len(train_dataloader)
for epoch in range(starting_epoch, args.num_train_epochs):
model.train() model.train()
if args.with_tracking: if args.with_tracking:
total_loss = 0 total_loss = 0
for step, batch in enumerate(train_dataloader): for step, batch in enumerate(train_dataloader):
# We need to skip steps until we reach the resumed step # We need to skip steps until we reach the resumed step
if args.resume_from_checkpoint and epoch == 0 and step < resume_step: if args.resume_from_checkpoint and epoch == starting_epoch:
if resume_step is not None and step < resume_step:
completed_steps += 1
continue continue
outputs = model(**batch) outputs = model(**batch)
loss = outputs.loss loss = outputs.loss