Update no_trainer scripts with new Accelerate functionalities (#16617)
Adds logging and state saving/loading to the Accelerate scripts.
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -150,6 +150,24 @@ def parse_args():
|
||||
"--hub_model_id", type=str, help="The name of the repository to keep in sync with the local `output_dir`."
|
||||
)
|
||||
parser.add_argument("--hub_token", type=str, help="The token to use to push to the Model Hub.")
|
||||
parser.add_argument("--hub_token", type=str, help="The token to use to push to the Model Hub.")
|
||||
parser.add_argument(
|
||||
"--checkpointing_steps",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--resume_from_checkpoint",
|
||||
type=str,
|
||||
default=None,
|
||||
help="If the training should continue from a checkpoint folder.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--with_tracking",
|
||||
required=False,
|
||||
help="Whether to load in all available experiment trackers from the environment and use them for logging.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Sanity checks
|
||||
@@ -173,7 +191,8 @@ def main():
|
||||
args = parse_args()
|
||||
|
||||
# Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
|
||||
accelerator = Accelerator()
|
||||
# If we're using tracking, we also need to initialize it here and it will pick up all supported trackers in the environment
|
||||
accelerator = Accelerator(log_with="all") if args.with_tracking else Accelerator()
|
||||
# Make one log on every process with the configuration for debugging.
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
@@ -376,14 +395,6 @@ def main():
|
||||
]
|
||||
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
|
||||
|
||||
# Prepare everything with our `accelerator`.
|
||||
model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
|
||||
model, optimizer, train_dataloader, eval_dataloader
|
||||
)
|
||||
|
||||
# Note -> the training dataloader needs to be prepared before we grab its length below (because its length will be
|
||||
# shorter in multiprocess)
|
||||
|
||||
# Scheduler and math around the number of training steps.
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
if args.max_train_steps is None:
|
||||
@@ -398,6 +409,23 @@ def main():
|
||||
num_training_steps=args.max_train_steps,
|
||||
)
|
||||
|
||||
# Prepare everything with our `accelerator`.
|
||||
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
|
||||
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
|
||||
)
|
||||
|
||||
# Figure out how often (every how many steps) we should save the Accelerator states
|
||||
if hasattr(args.checkpointing_steps, "isdigit"):
|
||||
checkpointing_steps = args.checkpointing_steps
|
||||
if args.checkpointing_steps.isdigit():
|
||||
checkpointing_steps = int(args.checkpointing_steps)
|
||||
else:
|
||||
checkpointing_steps = None
|
||||
|
||||
# We need to initialize the trackers we use, and also store our configuration
|
||||
if args.with_tracking:
|
||||
accelerator.init_trackers("glue_no_trainer", args)
|
||||
|
||||
# Get the metric function
|
||||
if args.task_name is not None:
|
||||
metric = load_metric("glue", args.task_name)
|
||||
@@ -417,12 +445,38 @@ def main():
|
||||
# Only show the progress bar once on each machine.
|
||||
progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
|
||||
completed_steps = 0
|
||||
# Potentially load in the weights and states from a previous save
|
||||
if args.resume_from_checkpoint:
|
||||
if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "":
|
||||
accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}")
|
||||
accelerator.load_state(args.resume_from_checkpoint)
|
||||
resume_step = None
|
||||
path = args.resume_from_checkpoint
|
||||
else:
|
||||
# Get the most recent checkpoint
|
||||
dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()]
|
||||
dirs.sort(key=os.path.getctime)
|
||||
path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last
|
||||
if "epoch" in path:
|
||||
args.num_train_epochs -= int(path.replace("epoch_", ""))
|
||||
else:
|
||||
resume_step = int(path.replace("step_", ""))
|
||||
args.num_train_epochs -= resume_step // len(train_dataloader)
|
||||
resume_step = (args.num_train_epochs * len(train_dataloader)) - resume_step
|
||||
|
||||
for epoch in range(args.num_train_epochs):
|
||||
model.train()
|
||||
if args.with_tracking:
|
||||
total_loss = 0
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
# We need to skip steps until we reach the resumed step
|
||||
if args.resume_from_checkpoint and epoch == 0 and step < resume_step:
|
||||
continue
|
||||
outputs = model(**batch)
|
||||
loss = outputs.loss
|
||||
# We keep track of the loss at each epoch
|
||||
if args.with_tracking:
|
||||
total_loss += loss.detach().float()
|
||||
loss = loss / args.gradient_accumulation_steps
|
||||
accelerator.backward(loss)
|
||||
if step % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
|
||||
@@ -432,6 +486,10 @@ def main():
|
||||
progress_bar.update(1)
|
||||
completed_steps += 1
|
||||
|
||||
if isinstance(checkpointing_steps, int):
|
||||
if completed_steps % checkpointing_steps == 0:
|
||||
accelerator.save_state(f"step_{completed_steps}")
|
||||
|
||||
if completed_steps >= args.max_train_steps:
|
||||
break
|
||||
|
||||
@@ -447,6 +505,16 @@ def main():
|
||||
eval_metric = metric.compute()
|
||||
logger.info(f"epoch {epoch}: {eval_metric}")
|
||||
|
||||
if args.with_tracking:
|
||||
accelerator.log(
|
||||
{
|
||||
"accuracy" if args.task_name is not None else "glue": eval_metric,
|
||||
"train_loss": total_loss,
|
||||
"epoch": epoch,
|
||||
},
|
||||
step=completed_steps,
|
||||
)
|
||||
|
||||
if args.push_to_hub and epoch < args.num_train_epochs - 1:
|
||||
accelerator.wait_for_everyone()
|
||||
unwrapped_model = accelerator.unwrap_model(model)
|
||||
@@ -457,6 +525,9 @@ def main():
|
||||
commit_message=f"Training in progress epoch {epoch}", blocking=False, auto_lfs_prune=True
|
||||
)
|
||||
|
||||
if args.checkpointing_steps == "epoch":
|
||||
accelerator.save_state(f"epoch_{epoch}")
|
||||
|
||||
if args.output_dir is not None:
|
||||
accelerator.wait_for_everyone()
|
||||
unwrapped_model = accelerator.unwrap_model(model)
|
||||
|
||||
Reference in New Issue
Block a user