From c61f116b639ef7ca1ada8ca06822d0bfb50c3890 Mon Sep 17 00:00:00 2001
From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Date: Thu, 1 Sep 2022 12:06:56 -0400
Subject: [PATCH] Tie weights after preparing the model in run_clm (#18855)

---
 examples/pytorch/language-modeling/run_clm_no_trainer.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/examples/pytorch/language-modeling/run_clm_no_trainer.py b/examples/pytorch/language-modeling/run_clm_no_trainer.py
index f5ea78f832..dee0fee8a0 100755
--- a/examples/pytorch/language-modeling/run_clm_no_trainer.py
+++ b/examples/pytorch/language-modeling/run_clm_no_trainer.py
@@ -477,10 +477,6 @@ def main():
     ]
     optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
 
-    # On TPU, the tie weights in our model have been disconnected, so we need to restore the ties.
-    if accelerator.distributed_type == DistributedType.TPU:
-        model.tie_weights()
-
     # Scheduler and math around the number of training steps.
     overrode_max_train_steps = False
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -500,6 +496,10 @@ def main():
         model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
     )
 
+    # On TPU, the tie weights in our model have been disconnected, so we need to restore the ties.
+    if accelerator.distributed_type == DistributedType.TPU:
+        model.tie_weights()
+
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
     if overrode_max_train_steps: