diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 60da8f3997..be0056f9c0 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -605,14 +605,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): Return: :obj:`torch.nn.Embedding`: Pointer to the input tokens Embeddings Module of the model. """ - base_model = getattr(self, self.base_model_prefix, self) # get the base model if needed - model_embeds = base_model._resize_token_embeddings(new_num_tokens) + model_embeds = self._resize_token_embeddings(new_num_tokens) if new_num_tokens is None: return model_embeds # Update base model and current model config self.config.vocab_size = new_num_tokens - base_model.vocab_size = new_num_tokens + self.vocab_size = new_num_tokens # Tie weights again if needed self.tie_weights() @@ -623,6 +622,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): old_embeddings = self.get_input_embeddings() new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens) self.set_input_embeddings(new_embeddings) + + # if word embeddings are not tied, make sure that lm head is resized as well + if self.get_output_embeddings() is not None and not self.config.tie_word_embeddings: + old_lm_head = self.get_output_embeddings() + new_lm_head = self._get_resized_lm_head(old_lm_head, new_num_tokens) + self.set_output_embeddings(new_lm_head) + return self.get_input_embeddings() def _get_resized_embeddings( @@ -653,9 +659,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): if old_num_tokens == new_num_tokens: return old_embeddings + if not isinstance(old_embeddings, nn.Embedding): + raise TypeError( + f"Old embeddings are of type {type(old_embeddings)}, which is not an instance of {nn.Embedding}." + f"You should either use a different resize function or make sure that `old_embeddings` are an instance of {nn.Embedding}." + ) + # Build new embeddings - new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim) - new_embeddings.to(old_embeddings.weight.device) + new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim).to(self.device) # initialize all new embeddings (in particular added tokens) self._init_weights(new_embeddings) @@ -666,6 +677,68 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): return new_embeddings + def _get_resized_lm_head( + self, old_lm_head: torch.nn.Linear, new_num_tokens: Optional[int] = None, transposed: Optional[bool] = False + ) -> torch.nn.Linear: + """ + Build a resized Linear Module from a provided old Linear Module. Increasing the size will add newly initialized + vectors at the end. Reducing the size will remove vectors from the end + + Args: + old_lm_head (:obj:`torch.nn.Linear`): + Old lm head liner layer to be resized. + new_num_tokens (:obj:`int`, `optional`): + New number of tokens in the linear matrix. + + Increasing the size will add newly initialized vectors at the end. Reducing the size will remove + vectors from the end. If not provided or :obj:`None`, just returns a pointer to the input tokens + :obj:`torch.nn.Linear`` module of the model without doing anything. + transposed (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether ``old_lm_head`` is transposed or not. If True ``old_lm_head.size()`` is ``lm_head_dim, + vocab_size`` else ``vocab_size, lm_head_dim``. + + Return: + :obj:`torch.nn.Linear`: Pointer to the resized Linear Module or the old Linear Module if + :obj:`new_num_tokens` is :obj:`None` + """ + if new_num_tokens is None: + return old_lm_head + + old_num_tokens, old_lm_head_dim = ( + old_lm_head.weight.size() if not transposed else old_lm_head.weight.t().size() + ) + + if old_num_tokens == new_num_tokens: + return old_lm_head + + if not isinstance(old_lm_head, nn.Linear): + raise TypeError( + f"Old language model head is of type {type(old_lm_head)}, which is not an instance of {nn.Linear}." + f"You should either use a different resize function or make sure that `old_embeddings` are an instance of {nn.Linear}." + ) + + # Build new lm head + new_lm_head_shape = (old_lm_head_dim, new_num_tokens) if not transposed else (new_num_tokens, old_lm_head_dim) + has_new_lm_head_bias = old_lm_head.bias is not None + new_lm_head = nn.Linear(*new_lm_head_shape, bias=has_new_lm_head_bias).to(self.device) + + # initialize new lm head (in particular added tokens) + self._init_weights(new_lm_head) + + num_tokens_to_copy = min(old_num_tokens, new_num_tokens) + + # Copy old lm head weights to new lm head + if not transposed: + new_lm_head.weight.data[:num_tokens_to_copy, :] = old_lm_head.weight.data[:num_tokens_to_copy, :] + else: + new_lm_head.weight.data[:, :num_tokens_to_copy] = old_lm_head.weight.data[:, :num_tokens_to_copy] + + # Copy bias weights to new lm head + if has_new_lm_head_bias: + new_lm_head.bias.data[:num_tokens_to_copy] = old_lm_head.bias.data[:num_tokens_to_copy] + + return new_lm_head + def init_weights(self): """ Initializes and prunes weights if needed. diff --git a/src/transformers/models/albert/modeling_albert.py b/src/transformers/models/albert/modeling_albert.py index 30c3e1ae30..f0505f8078 100755 --- a/src/transformers/models/albert/modeling_albert.py +++ b/src/transformers/models/albert/modeling_albert.py @@ -632,12 +632,6 @@ class AlbertModel(AlbertPreTrainedModel): def set_input_embeddings(self, value): self.embeddings.word_embeddings = value - def _resize_token_embeddings(self, new_num_tokens): - old_embeddings = self.embeddings.word_embeddings - new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens) - self.embeddings.word_embeddings = new_embeddings - return self.embeddings.word_embeddings - def _prune_heads(self, heads_to_prune): """ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} ALBERT has @@ -748,6 +742,9 @@ class AlbertForPreTraining(AlbertPreTrainedModel): def get_output_embeddings(self): return self.predictions.decoder + def set_output_embeddings(self, new_embeddings): + self.predictions.decoder = new_embeddings + def get_input_embeddings(self): return self.albert.embeddings.word_embeddings @@ -889,6 +886,9 @@ class AlbertForMaskedLM(AlbertPreTrainedModel): def get_output_embeddings(self): return self.predictions.decoder + def set_output_embeddings(self, new_embeddings): + self.predictions.decoder = new_embeddings + def get_input_embeddings(self): return self.albert.embeddings.word_embeddings diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py index 2da594e0dd..46649f2413 100755 --- a/src/transformers/models/bert/modeling_bert.py +++ b/src/transformers/models/bert/modeling_bert.py @@ -905,6 +905,9 @@ class BertForPreTraining(BertPreTrainedModel): def get_output_embeddings(self): return self.cls.predictions.decoder + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @replace_return_docstrings(output_type=BertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC) def forward( @@ -1010,6 +1013,9 @@ class BertLMHeadModel(BertPreTrainedModel): def get_output_embeddings(self): return self.cls.predictions.decoder + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC) def forward( @@ -1131,6 +1137,9 @@ class BertForMaskedLM(BertPreTrainedModel): def get_output_embeddings(self): return self.cls.predictions.decoder + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( tokenizer_class=_TOKENIZER_FOR_DOC, diff --git a/src/transformers/models/bert_generation/modeling_bert_generation.py b/src/transformers/models/bert_generation/modeling_bert_generation.py index 7efe4422ad..b242181b87 100755 --- a/src/transformers/models/bert_generation/modeling_bert_generation.py +++ b/src/transformers/models/bert_generation/modeling_bert_generation.py @@ -422,6 +422,9 @@ class BertGenerationDecoder(BertGenerationPreTrainedModel): def get_output_embeddings(self): return self.lm_head.decoder + def set_output_embeddings(self, new_embeddings): + self.lm_head.decoder = new_embeddings + @add_start_docstrings_to_model_forward(BERT_GENERATION_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC) def forward( diff --git a/src/transformers/models/ctrl/modeling_ctrl.py b/src/transformers/models/ctrl/modeling_ctrl.py index f85dd645ad..76f0402d77 100644 --- a/src/transformers/models/ctrl/modeling_ctrl.py +++ b/src/transformers/models/ctrl/modeling_ctrl.py @@ -496,6 +496,9 @@ class CTRLLMHeadModel(CTRLPreTrainedModel): def get_output_embeddings(self): return self.lm_head + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + def prepare_inputs_for_generation(self, input_ids, past=None, use_cache=None, **kwargs): # only last token for inputs_ids if past is defined in kwargs if past: diff --git a/src/transformers/models/distilbert/modeling_distilbert.py b/src/transformers/models/distilbert/modeling_distilbert.py index df89a3bc1a..a724a61228 100755 --- a/src/transformers/models/distilbert/modeling_distilbert.py +++ b/src/transformers/models/distilbert/modeling_distilbert.py @@ -508,6 +508,9 @@ class DistilBertForMaskedLM(DistilBertPreTrainedModel): def get_output_embeddings(self): return self.vocab_projector + def set_output_embeddings(self, new_embeddings): + self.vocab_projector = new_embeddings + @add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format("batch_size, num_choices")) @add_code_sample_docstrings( tokenizer_class=_TOKENIZER_FOR_DOC, diff --git a/src/transformers/models/electra/modeling_electra.py b/src/transformers/models/electra/modeling_electra.py index 0886f51afd..dbd16516de 100644 --- a/src/transformers/models/electra/modeling_electra.py +++ b/src/transformers/models/electra/modeling_electra.py @@ -1003,6 +1003,9 @@ class ElectraForMaskedLM(ElectraPreTrainedModel): def get_output_embeddings(self): return self.generator_lm_head + def set_output_embeddings(self, word_embeddings): + self.generator_lm_head = word_embeddings + @add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( tokenizer_class=_TOKENIZER_FOR_DOC, diff --git a/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py b/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py index 956ddfb0f8..5c27559fd8 100644 --- a/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py +++ b/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py @@ -194,6 +194,9 @@ class EncoderDecoderModel(PreTrainedModel): def get_output_embeddings(self): return self.decoder.get_output_embeddings() + def set_output_embeddings(self, new_embeddings): + return self.decoder.set_output_embeddings(new_embeddings) + @classmethod def from_encoder_decoder_pretrained( cls, diff --git a/src/transformers/models/funnel/modeling_funnel.py b/src/transformers/models/funnel/modeling_funnel.py index cfd8dada01..04d2c8c332 100644 --- a/src/transformers/models/funnel/modeling_funnel.py +++ b/src/transformers/models/funnel/modeling_funnel.py @@ -1167,6 +1167,9 @@ class FunnelForMaskedLM(FunnelPreTrainedModel): def get_output_embeddings(self): return self.lm_head + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + @add_start_docstrings_to_model_forward(FUNNEL_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( tokenizer_class=_TOKENIZER_FOR_DOC, diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index ae4edc80c9..cc3f288903 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -816,6 +816,9 @@ class GPT2LMHeadModel(GPT2PreTrainedModel): def get_output_embeddings(self): return self.lm_head + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs): token_type_ids = kwargs.get("token_type_ids", None) # only last token for inputs_ids if past is defined in kwargs @@ -945,6 +948,9 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel): def get_output_embeddings(self): return self.lm_head + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs): token_type_ids = kwargs.get("token_type_ids", None) # only last token for inputs_ids if past is defined in kwargs diff --git a/src/transformers/models/layoutlm/modeling_layoutlm.py b/src/transformers/models/layoutlm/modeling_layoutlm.py index a689c43650..04779cccc5 100644 --- a/src/transformers/models/layoutlm/modeling_layoutlm.py +++ b/src/transformers/models/layoutlm/modeling_layoutlm.py @@ -781,6 +781,9 @@ class LayoutLMForMaskedLM(LayoutLMPreTrainedModel): def get_output_embeddings(self): return self.cls.predictions.decoder + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + @add_start_docstrings_to_model_forward(LAYOUTLM_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) @add_code_sample_docstrings( tokenizer_class=_TOKENIZER_FOR_DOC, diff --git a/src/transformers/models/longformer/modeling_longformer.py b/src/transformers/models/longformer/modeling_longformer.py index 056a258a9e..8cdc82ea06 100755 --- a/src/transformers/models/longformer/modeling_longformer.py +++ b/src/transformers/models/longformer/modeling_longformer.py @@ -1632,6 +1632,9 @@ class LongformerForMaskedLM(LongformerPreTrainedModel): def get_output_embeddings(self): return self.lm_head.decoder + def set_output_embeddings(self, new_embeddings): + self.lm_head.decoder = new_embeddings + @add_start_docstrings_to_model_forward(LONGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @replace_return_docstrings(output_type=LongformerMaskedLMOutput, config_class=_CONFIG_FOR_DOC) def forward( diff --git a/src/transformers/models/mobilebert/modeling_mobilebert.py b/src/transformers/models/mobilebert/modeling_mobilebert.py index 4d5a2da87c..0b9cdd7c3e 100644 --- a/src/transformers/models/mobilebert/modeling_mobilebert.py +++ b/src/transformers/models/mobilebert/modeling_mobilebert.py @@ -641,7 +641,7 @@ class MobileBertLMPredictionHead(nn.Module): def forward(self, hidden_states): hidden_states = self.transform(hidden_states) hidden_states = hidden_states.matmul(torch.cat([self.decoder.weight.t(), self.dense.weight], dim=0)) - hidden_states += self.bias + hidden_states += self.decoder.bias return hidden_states @@ -949,26 +949,16 @@ class MobileBertForPreTraining(MobileBertPreTrainedModel): def get_output_embeddings(self): return self.cls.predictions.decoder - def tie_weights(self): - """ - Tie the weights between the input embeddings and the output embeddings. If the `torchscript` flag is set in the - configuration, can't handle parameter sharing so we are cloning the weights instead. - """ - output_embeddings = self.get_output_embeddings() - input_embeddings = self.get_input_embeddings() + def set_output_embeddings(self, new_embeddigs): + self.cls.predictions.decoder = new_embeddigs - resized_dense = nn.Linear( - input_embeddings.num_embeddings, self.config.hidden_size - self.config.embedding_size, bias=False + def resize_token_embeddings(self, new_num_tokens: Optional[int] = None) -> torch.nn.Embedding: + # resize dense output embedings at first + self.cls.predictions.dense = self._get_resized_lm_head( + self.cls.predictions.dense, new_num_tokens=new_num_tokens, transposed=True ) - kept_data = self.cls.predictions.dense.weight.data[ - ..., : min(self.cls.predictions.dense.weight.data.shape[1], resized_dense.weight.data.shape[1]) - ] - resized_dense.weight.data[..., : self.cls.predictions.dense.weight.data.shape[1]] = kept_data - self.cls.predictions.dense = resized_dense - self.cls.predictions.dense.to(self.device) - if output_embeddings is not None and self.config.tie_word_embeddings: - self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings()) + return super().resize_token_embeddings(new_num_tokens=new_num_tokens) @add_start_docstrings_to_model_forward(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @replace_return_docstrings(output_type=MobileBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC) @@ -1067,26 +1057,15 @@ class MobileBertForMaskedLM(MobileBertPreTrainedModel): def get_output_embeddings(self): return self.cls.predictions.decoder - def tie_weights(self): - """ - Tie the weights between the input embeddings and the output embeddings. If the `torchscript` flag is set in the - configuration, can't handle parameter sharing so we are cloning the weights instead. - """ - output_embeddings = self.get_output_embeddings() - input_embeddings = self.get_input_embeddings() + def set_output_embeddings(self, new_embeddigs): + self.cls.predictions.decoder = new_embeddigs - resized_dense = nn.Linear( - input_embeddings.num_embeddings, self.config.hidden_size - self.config.embedding_size, bias=False + def resize_token_embeddings(self, new_num_tokens: Optional[int] = None) -> torch.nn.Embedding: + # resize dense output embedings at first + self.cls.predictions.dense = self._get_resized_lm_head( + self.cls.predictions.dense, new_num_tokens=new_num_tokens, transposed=True ) - kept_data = self.cls.predictions.dense.weight.data[ - ..., : min(self.cls.predictions.dense.weight.data.shape[1], resized_dense.weight.data.shape[1]) - ] - resized_dense.weight.data[..., : self.cls.predictions.dense.weight.data.shape[1]] = kept_data - self.cls.predictions.dense = resized_dense - self.cls.predictions.dense.to(self.device) - - if output_embeddings is not None and self.config.tie_word_embeddings: - self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings()) + return super().resize_token_embeddings(new_num_tokens=new_num_tokens) @add_start_docstrings_to_model_forward(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( diff --git a/src/transformers/models/openai/modeling_openai.py b/src/transformers/models/openai/modeling_openai.py index 46f609d896..2d41461382 100644 --- a/src/transformers/models/openai/modeling_openai.py +++ b/src/transformers/models/openai/modeling_openai.py @@ -542,6 +542,9 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel): def get_output_embeddings(self): return self.lm_head + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + @add_start_docstrings_to_model_forward(OPENAI_GPT_INPUTS_DOCSTRING) @add_code_sample_docstrings( tokenizer_class=_TOKENIZER_FOR_DOC, @@ -628,6 +631,9 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel): def get_output_embeddings(self): return self.lm_head + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + @add_start_docstrings_to_model_forward(OPENAI_GPT_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=OpenAIGPTDoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC) def forward( diff --git a/src/transformers/models/prophetnet/modeling_prophetnet.py b/src/transformers/models/prophetnet/modeling_prophetnet.py index 7421ceffbe..b3af9c62b2 100644 --- a/src/transformers/models/prophetnet/modeling_prophetnet.py +++ b/src/transformers/models/prophetnet/modeling_prophetnet.py @@ -1703,6 +1703,9 @@ class ProphetNetForConditionalGeneration(ProphetNetPreTrainedModel): def get_output_embeddings(self): return self.lm_head + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + def get_input_embeddings(self): return self.prophetnet.word_embeddings @@ -1901,6 +1904,9 @@ class ProphetNetForCausalLM(ProphetNetPreTrainedModel): def get_output_embeddings(self): return self.lm_head + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + @add_start_docstrings_to_model_forward(PROPHETNET_STANDALONE_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=ProphetNetDecoderLMOutput, config_class=_CONFIG_FOR_DOC) def forward( diff --git a/src/transformers/models/rag/modeling_rag.py b/src/transformers/models/rag/modeling_rag.py index af1c0e6c67..1192c70d9e 100644 --- a/src/transformers/models/rag/modeling_rag.py +++ b/src/transformers/models/rag/modeling_rag.py @@ -1459,6 +1459,9 @@ class RagTokenForGeneration(RagPreTrainedModel): def get_output_embeddings(self): return self.rag.generator.get_output_embeddings() + def set_output_embeddings(self, new_embeddings): + return self.rag.generator.set_output_embeddings(new_embeddings) + def shift_tokens_right(self, input_ids, start_token_id=None): """Shift input ids one token to the right, and pad with start_token_id""" if start_token_id is None: diff --git a/src/transformers/models/reformer/modeling_reformer.py b/src/transformers/models/reformer/modeling_reformer.py index 29363122fe..f2bb57b457 100755 --- a/src/transformers/models/reformer/modeling_reformer.py +++ b/src/transformers/models/reformer/modeling_reformer.py @@ -2197,6 +2197,9 @@ class ReformerModelWithLMHead(ReformerPreTrainedModel): def get_output_embeddings(self): return self.lm_head.decoder + def set_output_embeddings(self, new_embeddings): + self.lm_head.decoder = new_embeddings + @add_start_docstrings_to_model_forward(REFORMER_INPUTS_DOCSTRING) @add_code_sample_docstrings( tokenizer_class=_TOKENIZER_FOR_DOC, @@ -2309,6 +2312,9 @@ class ReformerForMaskedLM(ReformerPreTrainedModel): def get_output_embeddings(self): return self.lm_head.decoder + def set_output_embeddings(self, new_embeddings): + self.lm_head.decoder = new_embeddings + @add_start_docstrings_to_model_forward(REFORMER_INPUTS_DOCSTRING) @add_code_sample_docstrings( tokenizer_class=_TOKENIZER_FOR_DOC, diff --git a/src/transformers/models/roberta/modeling_roberta.py b/src/transformers/models/roberta/modeling_roberta.py index cc17f8a204..09606953cc 100644 --- a/src/transformers/models/roberta/modeling_roberta.py +++ b/src/transformers/models/roberta/modeling_roberta.py @@ -752,6 +752,9 @@ class RobertaForCausalLM(RobertaPreTrainedModel): def get_output_embeddings(self): return self.lm_head.decoder + def set_output_embeddings(self, new_embeddings): + self.lm_head.decoder = new_embeddings + @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC) def forward( @@ -873,6 +876,9 @@ class RobertaForMaskedLM(RobertaPreTrainedModel): def get_output_embeddings(self): return self.lm_head.decoder + def set_output_embeddings(self, new_embeddings): + self.lm_head.decoder = new_embeddings + @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( tokenizer_class=_TOKENIZER_FOR_DOC, diff --git a/src/transformers/models/squeezebert/modeling_squeezebert.py b/src/transformers/models/squeezebert/modeling_squeezebert.py index cb1fb812b9..072dd17dc7 100644 --- a/src/transformers/models/squeezebert/modeling_squeezebert.py +++ b/src/transformers/models/squeezebert/modeling_squeezebert.py @@ -655,6 +655,9 @@ class SqueezeBertForMaskedLM(SqueezeBertPreTrainedModel): def get_output_embeddings(self): return self.cls.predictions.decoder + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + @add_start_docstrings_to_model_forward(SQUEEZEBERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) @add_code_sample_docstrings( tokenizer_class=_TOKENIZER_FOR_DOC, diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index f20b5cc5bc..5eb291eafb 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -1363,6 +1363,9 @@ class T5ForConditionalGeneration(T5PreTrainedModel): self.encoder.set_input_embeddings(new_embeddings) self.decoder.set_input_embeddings(new_embeddings) + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + def get_output_embeddings(self): return self.lm_head diff --git a/src/transformers/models/xlm/modeling_xlm.py b/src/transformers/models/xlm/modeling_xlm.py index b0667edab7..47c53a9bda 100755 --- a/src/transformers/models/xlm/modeling_xlm.py +++ b/src/transformers/models/xlm/modeling_xlm.py @@ -688,6 +688,9 @@ class XLMWithLMHeadModel(XLMPreTrainedModel): def get_output_embeddings(self): return self.pred_layer.proj + def set_output_embeddings(self, new_embeddings): + self.pred_layer.proj = new_embeddings + def prepare_inputs_for_generation(self, input_ids, **kwargs): mask_token_id = self.config.mask_token_id lang_id = self.config.lang_id diff --git a/src/transformers/models/xlnet/modeling_xlnet.py b/src/transformers/models/xlnet/modeling_xlnet.py index b315fc8fe4..cf8d67695c 100755 --- a/src/transformers/models/xlnet/modeling_xlnet.py +++ b/src/transformers/models/xlnet/modeling_xlnet.py @@ -1312,6 +1312,9 @@ class XLNetLMHeadModel(XLNetPreTrainedModel): def get_output_embeddings(self): return self.lm_loss + def set_output_embeddings(self, new_embeddings): + self.lm_loss = new_embeddings + def prepare_inputs_for_generation(self, input_ids, past=None, use_mems=None, **kwargs): # Add dummy token at the end (no attention on this one) diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py index de898390fc..775272b411 100755 --- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py +++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py @@ -781,6 +781,9 @@ class {{cookiecutter.camelcase_modelname}}ForMaskedLM({{cookiecutter.camelcase_m def get_output_embeddings(self): return self.cls.predictions.decoder + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + @add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) @add_code_sample_docstrings( tokenizer_class=_TOKENIZER_FOR_DOC, diff --git a/tests/test_modeling_bart.py b/tests/test_modeling_bart.py index 2cadaa09c2..9d3bb3392a 100644 --- a/tests/test_modeling_bart.py +++ b/tests/test_modeling_bart.py @@ -208,6 +208,10 @@ class BARTModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): def test_inputs_embeds(self): pass + @unittest.skip("TODO: Decoder embeddings cannot be resized at the moment") + def test_resize_embeddings_untied(self): + pass + @require_sentencepiece @require_tokenizers def test_tiny_model(self): diff --git a/tests/test_modeling_blenderbot.py b/tests/test_modeling_blenderbot.py index 19fee17ba0..a81e62e4ef 100644 --- a/tests/test_modeling_blenderbot.py +++ b/tests/test_modeling_blenderbot.py @@ -128,6 +128,10 @@ class BlenderbotTesterMixin(ModelTesterMixin, unittest.TestCase): def test_feed_forward_chunking(self): pass + @unittest.skip("TODO: Decoder embeddings cannot be resized at the moment") + def test_resize_embeddings_untied(self): + pass + @unittest.skipUnless(torch_device != "cpu", "3B test too slow on CPU.") @require_torch diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 9375e11085..dc8dc075b3 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -815,6 +815,10 @@ class ModelTesterMixin: # Check that the model can still do a forward pass successfully (every parameter should be resized) # Input ids should be clamped to the maximum size of the vocabulary inputs_dict["input_ids"].clamp_(max=model_vocab_size - 15 - 1) + + # make sure that decoder_input_ids are resized as well + if "decoder_input_ids" in inputs_dict: + inputs_dict["decoder_input_ids"].clamp_(max=model_vocab_size - 15 - 1) model(**self._prepare_for_class(inputs_dict, model_class)) # Check that adding and removing tokens has not modified the first part of the embedding matrix. @@ -825,6 +829,57 @@ class ModelTesterMixin: self.assertTrue(models_equal) + def test_resize_embeddings_untied(self): + ( + original_config, + inputs_dict, + ) = self.model_tester.prepare_config_and_inputs_for_common() + if not self.test_resize_embeddings: + return + + original_config.tie_word_embeddings = False + + # if model cannot untied embeddings -> leave test + if original_config.tie_word_embeddings: + return + + for model_class in self.all_model_classes: + config = copy.deepcopy(original_config) + model = model_class(config).to(torch_device) + + # if no output embeddings -> leave test + if model.get_output_embeddings() is None: + continue + + # Check that resizing the token embeddings with a larger vocab size increases the model's vocab size + model_vocab_size = config.vocab_size + model.resize_token_embeddings(model_vocab_size + 10) + self.assertEqual(model.config.vocab_size, model_vocab_size + 10) + output_embeds = model.get_output_embeddings() + self.assertEqual(output_embeds.weight.shape[0], model_vocab_size + 10) + # Check bias if present + if output_embeds.bias is not None: + self.assertEqual(output_embeds.bias.shape[0], model_vocab_size + 10) + # Check that the model can still do a forward pass successfully (every parameter should be resized) + model(**self._prepare_for_class(inputs_dict, model_class)) + + # Check that resizing the token embeddings with a smaller vocab size decreases the model's vocab size + model.resize_token_embeddings(model_vocab_size - 15) + self.assertEqual(model.config.vocab_size, model_vocab_size - 15) + # Check that it actually resizes the embeddings matrix + output_embeds = model.get_output_embeddings() + self.assertEqual(output_embeds.weight.shape[0], model_vocab_size - 15) + # Check bias if present + if output_embeds.bias is not None: + self.assertEqual(output_embeds.bias.shape[0], model_vocab_size - 15) + # Check that the model can still do a forward pass successfully (every parameter should be resized) + # Input ids should be clamped to the maximum size of the vocabulary + inputs_dict["input_ids"].clamp_(max=model_vocab_size - 15 - 1) + if "decoder_input_ids" in inputs_dict: + inputs_dict["decoder_input_ids"].clamp_(max=model_vocab_size - 15 - 1) + # Check that the model can still do a forward pass successfully (every parameter should be resized) + model(**self._prepare_for_class(inputs_dict, model_class)) + def test_model_common_attributes(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() diff --git a/tests/test_modeling_fsmt.py b/tests/test_modeling_fsmt.py index d5583a864f..60a52756ed 100644 --- a/tests/test_modeling_fsmt.py +++ b/tests/test_modeling_fsmt.py @@ -226,15 +226,9 @@ class FSMTModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): def test_tie_model_weights(self): pass - # def test_auto_model(self): - # # XXX: add a tiny model to s3? - # model_name = "facebook/wmt19-ru-en-tiny" - # tiny = AutoModel.from_pretrained(model_name) # same vocab size - # tok = AutoTokenizer.from_pretrained(model_name) # same tokenizer - # inputs_dict = tok.batch_encode_plus(["Hello my friends"], return_tensors="pt") - - # with torch.no_grad(): - # tiny(**inputs_dict) + @unittest.skip("TODO: Decoder embeddings cannot be resized at the moment") + def test_resize_embeddings_untied(self): + pass @require_torch diff --git a/tests/test_modeling_reformer.py b/tests/test_modeling_reformer.py index 788e1b8729..571da15f8b 100644 --- a/tests/test_modeling_reformer.py +++ b/tests/test_modeling_reformer.py @@ -574,6 +574,10 @@ class ReformerTesterMixin: # reformer cannot keep gradients in attentions or hidden states return + def test_resize_embeddings_untied(self): + # reformer cannot resize embeddings that easily + return + @require_torch class ReformerLocalAttnModelTest(ReformerTesterMixin, GenerationTesterMixin, ModelTesterMixin, unittest.TestCase): diff --git a/tests/test_modeling_t5.py b/tests/test_modeling_t5.py index 534187119f..15d5866c91 100644 --- a/tests/test_modeling_t5.py +++ b/tests/test_modeling_t5.py @@ -444,6 +444,20 @@ class T5ModelTester: ) ) + def check_resize_embeddings_t5_v1_1( + self, + config, + ): + prev_vocab_size = config.vocab_size + + config.tie_word_embeddings = False + model = T5ForConditionalGeneration(config=config).to(torch_device).eval() + model.resize_token_embeddings(prev_vocab_size - 10) + + self.parent.assertEqual(model.get_input_embeddings().weight.shape[0], prev_vocab_size - 10) + self.parent.assertEqual(model.get_output_embeddings().weight.shape[0], prev_vocab_size - 10) + self.parent.assertEqual(model.config.vocab_size, prev_vocab_size - 10) + def prepare_config_and_inputs_for_common(self): config_and_inputs = self.prepare_config_and_inputs() ( @@ -480,7 +494,7 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): ) test_pruning = False test_torchscript = True - test_resize_embeddings = False + test_resize_embeddings = True test_model_parallel = True is_encoder_decoder = True @@ -536,6 +550,10 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_model_fp16_forward(*config_and_inputs) + def test_v1_1_resize_embeddings(self): + config = self.model_tester.prepare_config_and_inputs()[0] + self.model_tester.check_resize_embeddings_t5_v1_1(config) + @slow def test_model_from_pretrained(self): for model_name in T5_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: diff --git a/tests/test_modeling_transfo_xl.py b/tests/test_modeling_transfo_xl.py index e8bed3cfbd..cdb80a9135 100644 --- a/tests/test_modeling_transfo_xl.py +++ b/tests/test_modeling_transfo_xl.py @@ -299,6 +299,10 @@ class TransfoXLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestC self.assertEqual(model_vocab_size, model.config.vocab_size) self.assertEqual(model_embed.emb_layers[layer].weight.shape[0], cloned_embeddings[layer].shape[0]) + def test_resize_embeddings_untied(self): + # transfo-xl requires special resize for lm-head + return + @require_torch class TransfoXLModelLanguageGenerationTest(unittest.TestCase):