From 5a4a76edb7ac6bbc764392e89adc11adda91f3e5 Mon Sep 17 00:00:00 2001 From: bayllama <142558246+bayllama@users.noreply.github.com> Date: Tue, 23 Jul 2024 02:28:44 -0700 Subject: [PATCH] Modify resize_token_embeddings to ensure output type is same as input (#31979) * Change resize_token_embeddings to make it return same Class that is passed to it * Add explanatory comment as requested in review * Add explanatory comments for add resizing function in lxmert * Add comment for padding_idx and moving _resize_bias in lxmert to LxmertForPreTraining --------- Co-authored-by: Prashanth Sateesh Co-authored-by: Prashanth Sateesh --- src/transformers/modeling_utils.py | 13 ++++++++++++- .../models/lxmert/modeling_lxmert.py | 16 ++++++++++++++++ tests/test_modeling_common.py | 5 +++++ 3 files changed, 33 insertions(+), 1 deletion(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index a20b7d941f..81403f524f 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -2128,7 +2128,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix else: new_embeddings.weight.data[:n, :] = old_embeddings.weight.data[:n, :] - return new_embeddings + # Replace weights in old_embeddings and return to maintain the same embedding type. + # This ensures correct functionality when a Custom Embedding class is passed as input. + # The input and output embedding types remain consistent. (c.f. https://github.com/huggingface/transformers/pull/31979) + old_embeddings.weight.data = new_embeddings.weight.data + old_embeddings.num_embeddings = new_embeddings.weight.data.shape[0] + + # If the new number of tokens is smaller than the original `padding_idx`, the `padding_idx` + # will be set to `None` in the resized embeddings. 
+        if old_embeddings.padding_idx is not None and (new_num_tokens - 1) < old_embeddings.padding_idx:
+            old_embeddings.padding_idx = None
+
+        return old_embeddings
 
     def _get_resized_lm_head(
         self, old_lm_head: nn.Linear, new_num_tokens: Optional[int] = None, transposed: Optional[bool] = False
diff --git a/src/transformers/models/lxmert/modeling_lxmert.py b/src/transformers/models/lxmert/modeling_lxmert.py
index b77b873183..9113fc4fd0 100644
--- a/src/transformers/models/lxmert/modeling_lxmert.py
+++ b/src/transformers/models/lxmert/modeling_lxmert.py
@@ -1072,6 +1072,22 @@ class LxmertForPreTraining(LxmertPreTrainedModel):
             }
         self.visual_losses = visual_losses
 
+    def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
+        # Resize the prediction bias to the actual row count of the resized embeddings; with
+        # `pad_to_multiple_of` that count can be larger than the requested `new_num_tokens`.
+        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
+        self.cls.predictions.bias = self._resize_bias(self.cls.predictions.bias, new_embeddings.weight.shape[0])
+        return new_embeddings
+
+    def _resize_bias(self, bias, new_num_tokens: int):
+        # Return `bias` as a new Parameter of length `new_num_tokens`: truncated when shrinking,
+        # zero-padded (same dtype and device as `bias`) when growing.
+        old_num_tokens = bias.shape[0]
+        if new_num_tokens <= old_num_tokens:
+            return nn.Parameter(bias[:new_num_tokens])
+        extra_bias = torch.zeros(new_num_tokens - old_num_tokens, dtype=bias.dtype, device=bias.device)
+        return nn.Parameter(torch.cat([bias, extra_bias]))
+
     def resize_num_qa_labels(self, num_labels):
         """
         Build a resized question answering linear layer Module from a provided new linear layer.
Increasing the size diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index dd041188cd..19a945aec5 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -1755,6 +1755,8 @@ class ModelTesterMixin: config = copy.deepcopy(original_config) model = model_class(config) model.to(torch_device) + model_embed_pre_resize = model.get_input_embeddings() + type_model_embed_pre_resize = type(model_embed_pre_resize) if self.model_tester.is_training is False: model.eval() @@ -1774,6 +1776,9 @@ class ModelTesterMixin: self.assertEqual(new_model_vocab_size, model_vocab_size + 10) # Check that it actually resizes the embeddings matrix self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] + 10) + # Check to make sure the type of embeddings returned post resizing is same as type of input + type_model_embed_post_resize = type(model_embed) + self.assertEqual(type_model_embed_pre_resize, type_model_embed_post_resize) # Check that the model can still do a forward pass successfully (every parameter should be resized) model(**self._prepare_for_class(inputs_dict, model_class))