From e54a1b49aa6268c484625c6374f952f318914743 Mon Sep 17 00:00:00 2001 From: Atharva Ingle Date: Thu, 18 Aug 2022 23:16:57 +0530 Subject: [PATCH] `model.tie_weights()` should be applied after `accelerator.prepare()` (#18676) * `model.tie_weights()` should be applied after `accelerator.prepare` Weight tying should be done after the model has been moved to XLA device as mentioned on PyTorch/XLA Troubleshooting guide [here](https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md#xla-tensor-quirks) * format code --- examples/pytorch/language-modeling/run_mlm_no_trainer.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/pytorch/language-modeling/run_mlm_no_trainer.py b/examples/pytorch/language-modeling/run_mlm_no_trainer.py index c5f6aad412..c01f870cdd 100755 --- a/examples/pytorch/language-modeling/run_mlm_no_trainer.py +++ b/examples/pytorch/language-modeling/run_mlm_no_trainer.py @@ -518,10 +518,6 @@ def main(): ] optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate) - # On TPU, the tie weights in our model have been disconnected, so we need to restore the ties. - if accelerator.distributed_type == DistributedType.TPU: - model.tie_weights() - # Note -> the training dataloader needs to be prepared before we grab his length below (cause its length will be # shorter in multiprocess) @@ -544,6 +540,10 @@ def main(): model, optimizer, train_dataloader, eval_dataloader, lr_scheduler ) + # On TPU, the tie weights in our model have been disconnected, so we need to restore the ties. + if accelerator.distributed_type == DistributedType.TPU: + model.tie_weights() + # We need to recalculate our total training steps as the size of the training dataloader may have changed. num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) if overrode_max_train_steps: