Tpu tie weights (#13030)
* Fix tied weights on TPU * Manually tie weights in no trainer examples * Fix for test * One last missing * Getting owned by my scripts * Address review comments * Fix test * Fix tests * Fix reformer tests
This commit is contained in:
@@ -35,7 +35,7 @@ from torch.utils.data.dataloader import DataLoader
|
|||||||
from tqdm.auto import tqdm
|
from tqdm.auto import tqdm
|
||||||
|
|
||||||
import transformers
|
import transformers
|
||||||
from accelerate import Accelerator
|
from accelerate import Accelerator, DistributedType
|
||||||
from transformers import (
|
from transformers import (
|
||||||
CONFIG_MAPPING,
|
CONFIG_MAPPING,
|
||||||
MODEL_MAPPING,
|
MODEL_MAPPING,
|
||||||
@@ -403,6 +403,10 @@ def main():
|
|||||||
model, optimizer, train_dataloader, eval_dataloader
|
model, optimizer, train_dataloader, eval_dataloader
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# On TPU, the tied weights in our model have been disconnected, so we need to restore the ties.
|
||||||
|
if accelerator.distributed_type == DistributedType.TPU:
|
||||||
|
model.tie_weights()
|
||||||
|
|
||||||
# Note -> the training dataloader needs to be prepared before we grab its length below (because its length will be
|
# Note -> the training dataloader needs to be prepared before we grab its length below (because its length will be
|
||||||
# shorter in multiprocess)
|
# shorter in multiprocess)
|
||||||
|
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ from torch.utils.data.dataloader import DataLoader
|
|||||||
from tqdm.auto import tqdm
|
from tqdm.auto import tqdm
|
||||||
|
|
||||||
import transformers
|
import transformers
|
||||||
from accelerate import Accelerator
|
from accelerate import Accelerator, DistributedType
|
||||||
from transformers import (
|
from transformers import (
|
||||||
CONFIG_MAPPING,
|
CONFIG_MAPPING,
|
||||||
MODEL_MAPPING,
|
MODEL_MAPPING,
|
||||||
@@ -448,6 +448,10 @@ def main():
|
|||||||
model, optimizer, train_dataloader, eval_dataloader
|
model, optimizer, train_dataloader, eval_dataloader
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# On TPU, the tied weights in our model have been disconnected, so we need to restore the ties.
|
||||||
|
if accelerator.distributed_type == DistributedType.TPU:
|
||||||
|
model.tie_weights()
|
||||||
|
|
||||||
# Note -> the training dataloader needs to be prepared before we grab its length below (because its length will be
|
# Note -> the training dataloader needs to be prepared before we grab its length below (because its length will be
|
||||||
# shorter in multiprocess)
|
# shorter in multiprocess)
|
||||||
|
|
||||||
|
|||||||
@@ -594,6 +594,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
self = getattr(self, self.base_model_prefix)
|
self = getattr(self, self.base_model_prefix)
|
||||||
self._tie_encoder_decoder_weights(self.encoder, self.decoder, self.base_model_prefix)
|
self._tie_encoder_decoder_weights(self.encoder, self.decoder, self.base_model_prefix)
|
||||||
|
|
||||||
|
for module in self.modules():
|
||||||
|
if hasattr(module, "_tie_weights"):
|
||||||
|
module._tie_weights()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _tie_encoder_decoder_weights(encoder: nn.Module, decoder: nn.Module, base_model_prefix: str):
|
def _tie_encoder_decoder_weights(encoder: nn.Module, decoder: nn.Module, base_model_prefix: str):
|
||||||
uninitialized_encoder_weights: List[str] = []
|
uninitialized_encoder_weights: List[str] = []
|
||||||
|
|||||||
@@ -860,8 +860,6 @@ class AlbertMLMHead(nn.Module):
|
|||||||
self.dense = nn.Linear(config.hidden_size, config.embedding_size)
|
self.dense = nn.Linear(config.hidden_size, config.embedding_size)
|
||||||
self.decoder = nn.Linear(config.embedding_size, config.vocab_size)
|
self.decoder = nn.Linear(config.embedding_size, config.vocab_size)
|
||||||
self.activation = ACT2FN[config.hidden_act]
|
self.activation = ACT2FN[config.hidden_act]
|
||||||
|
|
||||||
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
|
|
||||||
self.decoder.bias = self.bias
|
self.decoder.bias = self.bias
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states):
|
||||||
@@ -874,6 +872,10 @@ class AlbertMLMHead(nn.Module):
|
|||||||
|
|
||||||
return prediction_scores
|
return prediction_scores
|
||||||
|
|
||||||
|
def _tie_weights(self):
|
||||||
|
# To tie those two weights if they get disconnected (on TPU or when the bias is resized)
|
||||||
|
self.bias = self.decoder.bias
|
||||||
|
|
||||||
|
|
||||||
class AlbertSOPHead(nn.Module):
|
class AlbertSOPHead(nn.Module):
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
|
|||||||
@@ -430,16 +430,18 @@ class BertGenerationEncoder(BertGenerationPreTrainedModel):
|
|||||||
class BertGenerationOnlyLMHead(nn.Module):
|
class BertGenerationOnlyLMHead(nn.Module):
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
self.decoder = nn.Linear(config.hidden_size, config.vocab_size)
|
||||||
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
|
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
|
||||||
|
|
||||||
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
|
|
||||||
self.decoder.bias = self.bias
|
self.decoder.bias = self.bias
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states):
|
||||||
logits = self.decoder(hidden_states)
|
logits = self.decoder(hidden_states)
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
|
def _tie_weights(self):
|
||||||
|
# To tie those two weights if they get disconnected (on TPU or when the bias is resized)
|
||||||
|
self.bias = self.decoder.bias
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings(
|
@add_start_docstrings(
|
||||||
"""BertGeneration Model with a `language modeling` head on top for CLM fine-tuning. """,
|
"""BertGeneration Model with a `language modeling` head on top for CLM fine-tuning. """,
|
||||||
|
|||||||
@@ -948,10 +948,8 @@ class IBertLMHead(nn.Module):
|
|||||||
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
||||||
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
|
|
||||||
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
self.decoder = nn.Linear(config.hidden_size, config.vocab_size)
|
||||||
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
|
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
|
||||||
|
|
||||||
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
|
|
||||||
self.decoder.bias = self.bias
|
self.decoder.bias = self.bias
|
||||||
|
|
||||||
def forward(self, features, **kwargs):
|
def forward(self, features, **kwargs):
|
||||||
@@ -964,6 +962,10 @@ class IBertLMHead(nn.Module):
|
|||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
def _tie_weights(self):
|
||||||
|
# To tie those two weights if they get disconnected (on TPU or when the bias is resized)
|
||||||
|
self.bias = self.decoder.bias
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings(
|
@add_start_docstrings(
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -1336,10 +1336,8 @@ class LongformerLMHead(nn.Module):
|
|||||||
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
||||||
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
|
|
||||||
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
self.decoder = nn.Linear(config.hidden_size, config.vocab_size)
|
||||||
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
|
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
|
||||||
|
|
||||||
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
|
|
||||||
self.decoder.bias = self.bias
|
self.decoder.bias = self.bias
|
||||||
|
|
||||||
def forward(self, features, **kwargs):
|
def forward(self, features, **kwargs):
|
||||||
@@ -1352,6 +1350,10 @@ class LongformerLMHead(nn.Module):
|
|||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
def _tie_weights(self):
|
||||||
|
# To tie those two weights if they get disconnected (on TPU or when the bias is resized)
|
||||||
|
self.bias = self.decoder.bias
|
||||||
|
|
||||||
|
|
||||||
class LongformerPreTrainedModel(PreTrainedModel):
|
class LongformerPreTrainedModel(PreTrainedModel):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -1747,8 +1747,6 @@ class ReformerOnlyLMHead(nn.Module):
|
|||||||
self.chunk_size_lm_head = config.chunk_size_lm_head
|
self.chunk_size_lm_head = config.chunk_size_lm_head
|
||||||
self.decoder = nn.Linear(2 * config.hidden_size, config.vocab_size, bias=False)
|
self.decoder = nn.Linear(2 * config.hidden_size, config.vocab_size, bias=False)
|
||||||
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
|
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
|
||||||
|
|
||||||
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
|
|
||||||
self.decoder.bias = self.bias
|
self.decoder.bias = self.bias
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states):
|
||||||
@@ -1758,6 +1756,10 @@ class ReformerOnlyLMHead(nn.Module):
|
|||||||
hidden_states = self.decoder(hidden_states)
|
hidden_states = self.decoder(hidden_states)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
def _tie_weights(self):
|
||||||
|
# To tie those two weights if they get disconnected (on TPU or when the bias is resized)
|
||||||
|
self.bias = self.decoder.bias
|
||||||
|
|
||||||
|
|
||||||
class ReformerPreTrainedModel(PreTrainedModel):
|
class ReformerPreTrainedModel(PreTrainedModel):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -1124,10 +1124,8 @@ class RobertaLMHead(nn.Module):
|
|||||||
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
||||||
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
|
|
||||||
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
self.decoder = nn.Linear(config.hidden_size, config.vocab_size)
|
||||||
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
|
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
|
||||||
|
|
||||||
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
|
|
||||||
self.decoder.bias = self.bias
|
self.decoder.bias = self.bias
|
||||||
|
|
||||||
def forward(self, features, **kwargs):
|
def forward(self, features, **kwargs):
|
||||||
@@ -1140,6 +1138,10 @@ class RobertaLMHead(nn.Module):
|
|||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
def _tie_weights(self):
|
||||||
|
# To tie those two weights if they get disconnected (on TPU or when the bias is resized)
|
||||||
|
self.bias = self.decoder.bias
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings(
|
@add_start_docstrings(
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -364,7 +364,7 @@ class Trainer:
|
|||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
|
|
||||||
if self.place_model_on_device:
|
if self.place_model_on_device:
|
||||||
model = model.to(args.device)
|
self._move_model_to_device(model, args.device)
|
||||||
|
|
||||||
# Force n_gpu to 1 to avoid DataParallel as MP will manage the GPUs
|
# Force n_gpu to 1 to avoid DataParallel as MP will manage the GPUs
|
||||||
if self.is_model_parallel:
|
if self.is_model_parallel:
|
||||||
@@ -505,6 +505,12 @@ class Trainer:
|
|||||||
"""
|
"""
|
||||||
self.callback_handler.remove_callback(callback)
|
self.callback_handler.remove_callback(callback)
|
||||||
|
|
||||||
|
def _move_model_to_device(self, model, device):
|
||||||
|
model = model.to(device)
|
||||||
|
# Moving a model to an XLA device disconnects the tied weights, so we have to retie them.
|
||||||
|
if self.args.parallel_mode == ParallelMode.TPU and hasattr(model, "tie_weights"):
|
||||||
|
model.tie_weights()
|
||||||
|
|
||||||
def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None):
|
def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None):
|
||||||
if not self.args.remove_unused_columns:
|
if not self.args.remove_unused_columns:
|
||||||
return dataset
|
return dataset
|
||||||
@@ -1017,7 +1023,7 @@ class Trainer:
|
|||||||
# do_train is not a reliable argument, as it might not be set and .train() still called, so
|
# do_train is not a reliable argument, as it might not be set and .train() still called, so
|
||||||
# the following is a workaround:
|
# the following is a workaround:
|
||||||
if args.fp16_full_eval and not args.do_train:
|
if args.fp16_full_eval and not args.do_train:
|
||||||
self.model = self.model.to(args.device)
|
self._move_model_to_device(self.model, args.device)
|
||||||
|
|
||||||
if "model_path" in kwargs:
|
if "model_path" in kwargs:
|
||||||
resume_from_checkpoint = kwargs.pop("model_path")
|
resume_from_checkpoint = kwargs.pop("model_path")
|
||||||
@@ -1078,7 +1084,7 @@ class Trainer:
|
|||||||
# If model was re-initialized, put it on the right device and update self.model_wrapped
|
# If model was re-initialized, put it on the right device and update self.model_wrapped
|
||||||
if model_reloaded:
|
if model_reloaded:
|
||||||
if self.place_model_on_device:
|
if self.place_model_on_device:
|
||||||
self.model = self.model.to(args.device)
|
self._move_model_to_device(self.model, args.device)
|
||||||
self.model_wrapped = self.model
|
self.model_wrapped = self.model
|
||||||
|
|
||||||
# Keeping track of whether we can call len() on the dataset or not
|
# Keeping track of whether we can call len() on the dataset or not
|
||||||
|
|||||||
Reference in New Issue
Block a user