Tpu tie weights (#13030)
* Fix tied weights on TPU * Manually tie weights in no trainer examples * Fix for test * One last missing * Gettning owned by my scripts * Address review comments * Fix test * Fix tests * Fix reformer tests
This commit is contained in:
@@ -35,7 +35,7 @@ from torch.utils.data.dataloader import DataLoader
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
import transformers
|
||||
from accelerate import Accelerator
|
||||
from accelerate import Accelerator, DistributedType
|
||||
from transformers import (
|
||||
CONFIG_MAPPING,
|
||||
MODEL_MAPPING,
|
||||
@@ -403,6 +403,10 @@ def main():
|
||||
model, optimizer, train_dataloader, eval_dataloader
|
||||
)
|
||||
|
||||
# On TPU, the tie weights in our model have been disconnected, so we need to restore the ties.
|
||||
if accelerator.distributed_type == DistributedType.TPU:
|
||||
model.tie_weights()
|
||||
|
||||
# Note -> the training dataloader needs to be prepared before we grab his length below (cause its length will be
|
||||
# shorter in multiprocess)
|
||||
|
||||
|
||||
@@ -35,7 +35,7 @@ from torch.utils.data.dataloader import DataLoader
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
import transformers
|
||||
from accelerate import Accelerator
|
||||
from accelerate import Accelerator, DistributedType
|
||||
from transformers import (
|
||||
CONFIG_MAPPING,
|
||||
MODEL_MAPPING,
|
||||
@@ -448,6 +448,10 @@ def main():
|
||||
model, optimizer, train_dataloader, eval_dataloader
|
||||
)
|
||||
|
||||
# On TPU, the tie weights in our model have been disconnected, so we need to restore the ties.
|
||||
if accelerator.distributed_type == DistributedType.TPU:
|
||||
model.tie_weights()
|
||||
|
||||
# Note -> the training dataloader needs to be prepared before we grab his length below (cause its length will be
|
||||
# shorter in multiprocess)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user