model.tie_weights() should be applied after accelerator.prepare() (#18676)
* `model.tie_weights()` should be applied after `accelerator.prepare` Weight tying should be done after the model has been moved to XLA device as mentioned on PyTorch/XLA Troubleshooting guide [here](https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md#xla-tensor-quirks) * format code
This commit is contained in:
@@ -518,10 +518,6 @@ def main():
|
|||||||
]
|
]
|
||||||
optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
|
optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
|
||||||
|
|
||||||
# On TPU, the tie weights in our model have been disconnected, so we need to restore the ties.
|
|
||||||
if accelerator.distributed_type == DistributedType.TPU:
|
|
||||||
model.tie_weights()
|
|
||||||
|
|
||||||
# Note -> the training dataloader needs to be prepared before we grab his length below (cause its length will be
|
# Note -> the training dataloader needs to be prepared before we grab his length below (cause its length will be
|
||||||
# shorter in multiprocess)
|
# shorter in multiprocess)
|
||||||
|
|
||||||
@@ -544,6 +540,10 @@ def main():
|
|||||||
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
|
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# On TPU, the tie weights in our model have been disconnected, so we need to restore the ties.
|
||||||
|
if accelerator.distributed_type == DistributedType.TPU:
|
||||||
|
model.tie_weights()
|
||||||
|
|
||||||
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
||||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||||
if overrode_max_train_steps:
|
if overrode_max_train_steps:
|
||||||
|
|||||||
Reference in New Issue
Block a user