diff --git a/src/transformers/modeling_albert.py b/src/transformers/modeling_albert.py index c1540bda5f..4fae225212 100644 --- a/src/transformers/modeling_albert.py +++ b/src/transformers/modeling_albert.py @@ -579,6 +579,9 @@ class AlbertMLMHead(nn.Module): self.decoder = nn.Linear(config.embedding_size, config.vocab_size) self.activation = ACT2FN[config.hidden_act] + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + def forward(self, hidden_states): hidden_states = self.dense(hidden_states) hidden_states = self.activation(hidden_states) diff --git a/src/transformers/modeling_bert.py b/src/transformers/modeling_bert.py index cdc46b9662..48ada95c75 100644 --- a/src/transformers/modeling_bert.py +++ b/src/transformers/modeling_bert.py @@ -481,6 +481,9 @@ class BertLMPredictionHead(nn.Module): self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + def forward(self, hidden_states): hidden_states = self.transform(hidden_states) hidden_states = self.decoder(hidden_states) + self.bias diff --git a/src/transformers/modeling_roberta.py b/src/transformers/modeling_roberta.py index 56e983e01c..fc066cc7b8 100644 --- a/src/transformers/modeling_roberta.py +++ b/src/transformers/modeling_roberta.py @@ -306,6 +306,9 @@ class RobertaLMHead(nn.Module): self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + def forward(self, features, **kwargs): x = self.dense(features) x = gelu(x) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 719debcb3c..1281febd9f 100644 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -487,6 +487,8 @@ class ModelTesterMixin: self.assertEqual(model.config.vocab_size, model_vocab_size + 10) # Check that it actually resizes the embeddings matrix self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] + 10) + # Check that the model can still do a forward pass successfully (every parameter should be resized) + model(**inputs_dict) # Check that resizing the token embeddings with a smaller vocab size decreases the model's vocab size model_embed = model.resize_token_embeddings(model_vocab_size - 15) @@ -494,6 +496,11 @@ class ModelTesterMixin: # Check that it actually resizes the embeddings matrix self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] - 15) + # Check that the model can still do a forward pass successfully (every parameter should be resized) + # Input ids should be clamped to the maximum size of the vocabulary + inputs_dict["input_ids"].clamp_(max=model_vocab_size - 15 - 1) + model(**inputs_dict) + # Check that adding and removing tokens has not modified the first part of the embedding matrix. models_equal = True for p1, p2 in zip(cloned_embeddings, model_embed.weight):