model.tie_weights() should be applied after accelerator.prepare() (#18676)

* `model.tie_weights()` should be applied after `accelerator.prepare`

Weight tying should be done after the model has been moved to XLA device as mentioned on PyTorch/XLA Troubleshooting guide [here](https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md#xla-tensor-quirks)

* format code
This commit is contained in:
Atharva Ingle
2022-08-18 23:16:57 +05:30
committed by GitHub
parent bbbb453e58
commit e54a1b49aa

View File

@@ -518,10 +518,6 @@ def main():
] ]
optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate) optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
# On TPU, the tie weights in our model have been disconnected, so we need to restore the ties.
if accelerator.distributed_type == DistributedType.TPU:
model.tie_weights()
# Note -> the training dataloader needs to be prepared before we grab his length below (cause its length will be # Note -> the training dataloader needs to be prepared before we grab his length below (cause its length will be
# shorter in multiprocess) # shorter in multiprocess)
@@ -544,6 +540,10 @@ def main():
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
) )
# On TPU, the tie weights in our model have been disconnected, so we need to restore the ties.
if accelerator.distributed_type == DistributedType.TPU:
model.tie_weights()
# We need to recalculate our total training steps as the size of the training dataloader may have changed. # We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if overrode_max_train_steps: if overrode_max_train_steps: