streamlining 'checkpointing_steps' parsing (#18755)
This commit is contained in:
@@ -451,12 +451,9 @@ def main():
|
||||
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
||||
|
||||
# Figure out how many steps we should save the Accelerator states
|
||||
if hasattr(args.checkpointing_steps, "isdigit"):
|
||||
checkpointing_steps = args.checkpointing_steps
|
||||
if args.checkpointing_steps.isdigit():
|
||||
checkpointing_steps = int(args.checkpointing_steps)
|
||||
else:
|
||||
checkpointing_steps = None
|
||||
checkpointing_steps = args.checkpointing_steps
|
||||
if checkpointing_steps is not None and checkpointing_steps.isdigit():
|
||||
checkpointing_steps = int(checkpointing_steps)
|
||||
|
||||
# We need to initialize the trackers we use, and also store our configuration.
|
||||
# The trackers initializes automatically on the main process.
|
||||
|
||||
Reference in New Issue
Block a user