Tpu tie weights (#13030)

* Fix tied weights on TPU

* Manually tie weights in no trainer examples

* Fix for test

* One last missing

* Gettning owned by my scripts

* Address review comments

* Fix test

* Fix tests

* Fix reformer tests
This commit is contained in:
Sylvain Gugger
2021-08-06 20:41:39 +02:00
committed by GitHub
parent 1bf38611a4
commit 7fcee113c1
10 changed files with 51 additions and 21 deletions

View File

@@ -35,7 +35,7 @@ from torch.utils.data.dataloader import DataLoader
from tqdm.auto import tqdm
import transformers
from accelerate import Accelerator
from accelerate import Accelerator, DistributedType
from transformers import (
CONFIG_MAPPING,
MODEL_MAPPING,
@@ -403,6 +403,10 @@ def main():
model, optimizer, train_dataloader, eval_dataloader
)
# On TPU, the tie weights in our model have been disconnected, so we need to restore the ties.
if accelerator.distributed_type == DistributedType.TPU:
model.tie_weights()
# Note -> the training dataloader needs to be prepared before we grab his length below (cause its length will be
# shorter in multiprocess)

View File

@@ -35,7 +35,7 @@ from torch.utils.data.dataloader import DataLoader
from tqdm.auto import tqdm
import transformers
from accelerate import Accelerator
from accelerate import Accelerator, DistributedType
from transformers import (
CONFIG_MAPPING,
MODEL_MAPPING,
@@ -448,6 +448,10 @@ def main():
model, optimizer, train_dataloader, eval_dataloader
)
# On TPU, the tie weights in our model have been disconnected, so we need to restore the ties.
if accelerator.distributed_type == DistributedType.TPU:
model.tie_weights()
# Note -> the training dataloader needs to be prepared before we grab his length below (cause its length will be
# shorter in multiprocess)