From d6bf08f7f6f8bf5d6d94e9b10b7d8203906353ad Mon Sep 17 00:00:00 2001 From: Arthur <48595927+ArthurZucker@users.noreply.github.com> Date: Thu, 17 Aug 2023 17:00:32 +0200 Subject: [PATCH] [`resize_embedding`] Introduce `pad_to_multiple_of` and guidance (#25088) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix * revert cahnges and update resizing of embedding layer * use wraning * fixup * more styling nits * fix all tests that overload the embedding tests * 👀👀 remove breakpoint * remove useless overload + overload correctly where needed * resize lm head with new vocab size * reverse not necessary changes * style * fix CIs! * fix last CI tests, adapt bark and Marian * fixup --- src/transformers/modeling_utils.py | 49 ++++++++++++++++--- src/transformers/models/bark/modeling_bark.py | 13 +++-- src/transformers/models/bart/modeling_bart.py | 6 +-- .../modeling_bigbird_pegasus.py | 6 +-- .../models/blenderbot/modeling_blenderbot.py | 6 +-- .../modeling_blenderbot_small.py | 6 +-- .../modeling_gptsan_japanese.py | 6 +-- src/transformers/models/led/modeling_led.py | 6 +-- .../models/m2m_100/modeling_m2m_100.py | 4 -- .../models/marian/modeling_marian.py | 9 ++-- .../models/mbart/modeling_mbart.py | 6 +-- src/transformers/models/mvp/modeling_mvp.py | 4 +- .../models/nllb_moe/modeling_nllb_moe.py | 4 -- .../models/pegasus/modeling_pegasus.py | 6 +-- .../models/pegasus_x/modeling_pegasus_x.py | 4 -- .../models/plbart/modeling_plbart.py | 6 +-- .../speech_to_text/modeling_speech_to_text.py | 4 -- .../models/speecht5/modeling_speecht5.py | 4 -- .../models/whisper/modeling_whisper.py | 4 -- tests/test_modeling_common.py | 20 ++++++++ 20 files changed, 107 insertions(+), 66 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 66e3de6480..545e87c5fe 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1382,7 +1382,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix if hasattr(output_embeddings, "out_features") and hasattr(input_embeddings, "num_embeddings"): output_embeddings.out_features = input_embeddings.num_embeddings - def resize_token_embeddings(self, new_num_tokens: Optional[int] = None) -> nn.Embedding: + def resize_token_embeddings( + self, new_num_tokens: Optional[int] = None, pad_to_multiple_of: Optional[int] = None + ) -> nn.Embedding: """ Resizes input token embeddings matrix of the model if `new_num_tokens != config.vocab_size`. @@ -1393,11 +1395,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix The number of new tokens in the embedding matrix. Increasing the size will add newly initialized vectors at the end. Reducing the size will remove vectors from the end. If not provided or `None`, just returns a pointer to the input tokens `torch.nn.Embedding` module of the model without doing anything. + pad_to_multiple_of (`int`, *optional*): + If set will pad the embedding matrix to a multiple of the provided value. + + This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability + `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128. For more + details about this, or help on choosing the correct value for resizing, refer to this guide: + https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc Return: `torch.nn.Embedding`: Pointer to the input tokens Embeddings Module of the model. """ - model_embeds = self._resize_token_embeddings(new_num_tokens) + model_embeds = self._resize_token_embeddings(new_num_tokens, pad_to_multiple_of) if new_num_tokens is None: return model_embeds @@ -1410,21 +1419,24 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix return model_embeds - def _resize_token_embeddings(self, new_num_tokens): + def _resize_token_embeddings(self, new_num_tokens, pad_to_multiple_of=None): old_embeddings = self.get_input_embeddings() - new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens) + new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens, pad_to_multiple_of) self.set_input_embeddings(new_embeddings) # if word embeddings are not tied, make sure that lm head is resized as well if self.get_output_embeddings() is not None and not self.config.tie_word_embeddings: old_lm_head = self.get_output_embeddings() - new_lm_head = self._get_resized_lm_head(old_lm_head, new_num_tokens) + new_lm_head = self._get_resized_lm_head(old_lm_head, new_embeddings.weight.shape[0]) self.set_output_embeddings(new_lm_head) return self.get_input_embeddings() def _get_resized_embeddings( - self, old_embeddings: nn.Embedding, new_num_tokens: Optional[int] = None + self, + old_embeddings: nn.Embedding, + new_num_tokens: Optional[int] = None, + pad_to_multiple_of: Optional[int] = None, ) -> nn.Embedding: """ Build a resized Embedding Module from a provided token Embedding Module. Increasing the size will add newly @@ -1439,11 +1451,36 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix Increasing the size will add newly initialized vectors at the end. Reducing the size will remove vectors from the end. If not provided or `None`, just returns a pointer to the input tokens `torch.nn.Embedding` module of the model without doing anything. + pad_to_multiple_of (`int`, *optional*): + If set will pad the embedding matrix to a multiple of the provided value. + + This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability + `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128. For more + details about this, or help on choosing the correct value for resizing, refer to this guide: + https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc + Return: `torch.nn.Embedding`: Pointer to the resized Embedding Module or the old Embedding Module if `new_num_tokens` is `None` """ + + if pad_to_multiple_of is not None: + if not isinstance(pad_to_multiple_of, int): + raise ValueError( + f"Asking to pad the embedding matrix to a multiple of `{pad_to_multiple_of}`, which is not and integer. Please make sure to pass an integer" + ) + if new_num_tokens is None: + new_num_tokens = old_embeddings.weight.shape[0] + new_num_tokens = ((new_num_tokens // pad_to_multiple_of) + 1) * pad_to_multiple_of + else: + logger.warning( + "You are resizing the embedding layer without providing a `pad_to_multiple_of` parameter. This means that the new embeding" + f" dimension will be {new_num_tokens}. This might induce some performance reduction as *Tensor Cores* will not be available." + " For more details about this, or help on choosing the correct value for resizing, refer to this guide:" + " https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc" + ) + if new_num_tokens is None: return old_embeddings diff --git a/src/transformers/models/bark/modeling_bark.py b/src/transformers/models/bark/modeling_bark.py index 368c0b5e01..32c16de0dd 100644 --- a/src/transformers/models/bark/modeling_bark.py +++ b/src/transformers/models/bark/modeling_bark.py @@ -1077,18 +1077,25 @@ class BarkFineModel(BarkPreTrainedModel): # one lm_head for each codebook self.lm_heads = new_output_embeddings - def _resize_token_embeddings(self, new_num_tokens): + def _resize_token_embeddings(self, new_num_tokens, pad_to_multiple_of=None): old_embeddings_list = self.get_input_embeddings() new_embeddings_list = nn.ModuleList( - [self._get_resized_embeddings(old_embeddings, new_num_tokens) for old_embeddings in old_embeddings_list] + [ + self._get_resized_embeddings(old_embeddings, new_num_tokens, pad_to_multiple_of) + for old_embeddings in old_embeddings_list + ] ) self.set_input_embeddings(new_embeddings_list) + new_num_tokens = [embed.weight.shape[0] for embed in new_embeddings_list] # if word embeddings are not tied, make sure that lm head is resized as well if self.get_output_embeddings() is not None and not self.config.tie_word_embeddings: old_lm_head_list = self.get_output_embeddings() new_lm_head_list = nn.ModuleList( - [self._get_resized_lm_head(old_lm_head, new_num_tokens) for old_lm_head in old_lm_head_list] + [ + self._get_resized_lm_head(old_lm_head, new_num_token) + for old_lm_head, new_num_token in zip(old_lm_head_list, new_num_tokens) + ] ) self.set_output_embeddings(new_lm_head_list) diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index fe3fbb1f8c..09ec877022 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -1324,9 +1324,9 @@ class BartForConditionalGeneration(BartPreTrainedModel): def get_decoder(self): return self.model.get_decoder() - def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding: - new_embeddings = super().resize_token_embeddings(new_num_tokens) - self._resize_final_logits_bias(new_num_tokens) + def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding: + new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of) + self._resize_final_logits_bias(new_embeddings.weight.shape[0]) return new_embeddings def _resize_final_logits_bias(self, new_num_tokens: int) -> None: diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py index 43f876a7ee..0e0dda8b69 100755 --- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py @@ -2508,9 +2508,9 @@ class BigBirdPegasusForConditionalGeneration(BigBirdPegasusPreTrainedModel): def get_decoder(self): return self.model.get_decoder() - def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding: - new_embeddings = super().resize_token_embeddings(new_num_tokens) - self._resize_final_logits_bias(new_num_tokens) + def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding: + new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of) + self._resize_final_logits_bias(new_embeddings.weight.shape[0]) return new_embeddings def _resize_final_logits_bias(self, new_num_tokens: int) -> None: diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py index de34037868..20779fe66a 100755 --- a/src/transformers/models/blenderbot/modeling_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_blenderbot.py @@ -1277,9 +1277,9 @@ class BlenderbotForConditionalGeneration(BlenderbotPreTrainedModel): def get_decoder(self): return self.model.get_decoder() - def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding: - new_embeddings = super().resize_token_embeddings(new_num_tokens) - self._resize_final_logits_bias(new_num_tokens) + def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding: + new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of) + self._resize_final_logits_bias(new_embeddings.weight.shape[0]) return new_embeddings def _resize_final_logits_bias(self, new_num_tokens: int) -> None: diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py index ff6eda893f..dc2f1512ff 100755 --- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py @@ -1244,9 +1244,9 @@ class BlenderbotSmallForConditionalGeneration(BlenderbotSmallPreTrainedModel): def get_decoder(self): return self.model.get_decoder() - def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding: - new_embeddings = super().resize_token_embeddings(new_num_tokens) - self._resize_final_logits_bias(new_num_tokens) + def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding: + new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of) + self._resize_final_logits_bias(new_embeddings.weight.shape[0]) return new_embeddings def _resize_final_logits_bias(self, new_num_tokens: int) -> None: diff --git a/src/transformers/models/gptsan_japanese/modeling_gptsan_japanese.py b/src/transformers/models/gptsan_japanese/modeling_gptsan_japanese.py index f02aa2dc83..0d9301406d 100644 --- a/src/transformers/models/gptsan_japanese/modeling_gptsan_japanese.py +++ b/src/transformers/models/gptsan_japanese/modeling_gptsan_japanese.py @@ -1313,9 +1313,9 @@ class GPTSanJapaneseForConditionalGeneration(GPTSanJapanesePreTrainedModel): return self._shift_right(labels) # Copied from transformers.models.mbart.modeling_mbart.MBartForConditionalGeneration.resize_token_embeddings with MBart->GPTSanJapanese - def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding: - new_embeddings = super().resize_token_embeddings(new_num_tokens) - self._resize_final_logits_bias(new_num_tokens) + def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding: + new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of) + self._resize_final_logits_bias(new_embeddings.weight.shape[0]) return new_embeddings # Copied from transformers.models.mbart.modeling_mbart.MBartForConditionalGeneration._resize_final_logits_bias with MBart->GPTSanJapanese diff --git a/src/transformers/models/led/modeling_led.py b/src/transformers/models/led/modeling_led.py index 3cf827fe33..e405098bf0 100755 --- a/src/transformers/models/led/modeling_led.py +++ b/src/transformers/models/led/modeling_led.py @@ -2352,9 +2352,9 @@ class LEDForConditionalGeneration(LEDPreTrainedModel): def get_decoder(self): return self.led.get_decoder() - def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding: - new_embeddings = super().resize_token_embeddings(new_num_tokens) - self._resize_final_logits_bias(new_num_tokens) + def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding: + new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of) + self._resize_final_logits_bias(new_embeddings.weight.shape[0]) return new_embeddings def _resize_final_logits_bias(self, new_num_tokens: int) -> None: diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py index 116edbb27b..5d9dbccffd 100755 --- a/src/transformers/models/m2m_100/modeling_m2m_100.py +++ b/src/transformers/models/m2m_100/modeling_m2m_100.py @@ -1267,10 +1267,6 @@ class M2M100ForConditionalGeneration(M2M100PreTrainedModel): def get_decoder(self): return self.model.get_decoder() - def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding: - new_embeddings = super().resize_token_embeddings(new_num_tokens) - return new_embeddings - def get_output_embeddings(self): return self.lm_head diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py index 316aa57518..6c287151ee 100755 --- a/src/transformers/models/marian/modeling_marian.py +++ b/src/transformers/models/marian/modeling_marian.py @@ -1316,17 +1316,18 @@ class MarianMTModel(MarianPreTrainedModel): def get_decoder(self): return self.model.get_decoder() - def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding: - new_embeddings = super().resize_token_embeddings(new_num_tokens) + def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding: + new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of) if self.config.share_encoder_decoder_embeddings: self._resize_final_logits_bias(new_num_tokens) return new_embeddings - def _resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding: + def _resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of=None) -> nn.Embedding: old_embeddings = self.get_input_embeddings() - new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens) + new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens, pad_to_multiple_of) self.set_input_embeddings(new_embeddings) + new_num_tokens = new_embeddings.weight.shape[0] # update config.decoder_vocab_size if embeddings are tied if self.config.share_encoder_decoder_embeddings: self.config.decoder_vocab_size = new_num_tokens diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index 5319e47680..6faa2d7bc7 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -1294,9 +1294,9 @@ class MBartForConditionalGeneration(MBartPreTrainedModel): def get_decoder(self): return self.model.get_decoder() - def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding: - new_embeddings = super().resize_token_embeddings(new_num_tokens) - self._resize_final_logits_bias(new_num_tokens) + def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding: + new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of) + self._resize_final_logits_bias(new_embeddings.weight.shape[0]) return new_embeddings def _resize_final_logits_bias(self, new_num_tokens: int) -> None: diff --git a/src/transformers/models/mvp/modeling_mvp.py b/src/transformers/models/mvp/modeling_mvp.py index 4d03ecdcaa..c42ef51c53 100644 --- a/src/transformers/models/mvp/modeling_mvp.py +++ b/src/transformers/models/mvp/modeling_mvp.py @@ -1453,8 +1453,8 @@ class MvpForConditionalGeneration(MvpPreTrainedModel): def get_decoder(self): return self.model.get_decoder() - def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding: - new_embeddings = super().resize_token_embeddings(new_num_tokens) + def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding: + new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of) self._resize_final_logits_bias(new_num_tokens) return new_embeddings diff --git a/src/transformers/models/nllb_moe/modeling_nllb_moe.py b/src/transformers/models/nllb_moe/modeling_nllb_moe.py index b37c79ddfc..a7c02cdeba 100644 --- a/src/transformers/models/nllb_moe/modeling_nllb_moe.py +++ b/src/transformers/models/nllb_moe/modeling_nllb_moe.py @@ -1652,10 +1652,6 @@ class NllbMoeForConditionalGeneration(NllbMoePreTrainedModel): def get_decoder(self): return self.model.get_decoder() - def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding: - new_embeddings = super().resize_token_embeddings(new_num_tokens) - return new_embeddings - def get_output_embeddings(self): return self.lm_head diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py index 7cccf0df19..b64833a8f6 100755 --- a/src/transformers/models/pegasus/modeling_pegasus.py +++ b/src/transformers/models/pegasus/modeling_pegasus.py @@ -1327,9 +1327,9 @@ class PegasusForConditionalGeneration(PegasusPreTrainedModel): def get_decoder(self): return self.model.get_decoder() - def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding: - new_embeddings = super().resize_token_embeddings(new_num_tokens) - self._resize_final_logits_bias(new_num_tokens) + def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding: + new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of) + self._resize_final_logits_bias(new_embeddings.weight.shape[0]) return new_embeddings def _resize_final_logits_bias(self, new_num_tokens: int) -> None: diff --git a/src/transformers/models/pegasus_x/modeling_pegasus_x.py b/src/transformers/models/pegasus_x/modeling_pegasus_x.py index ba3b98b1ef..53e920e365 100755 --- a/src/transformers/models/pegasus_x/modeling_pegasus_x.py +++ b/src/transformers/models/pegasus_x/modeling_pegasus_x.py @@ -1552,10 +1552,6 @@ class PegasusXForConditionalGeneration(PegasusXPreTrainedModel): def get_decoder(self): return self.model.get_decoder() - def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding: - new_embeddings = super().resize_token_embeddings(new_num_tokens) - return new_embeddings - def get_output_embeddings(self): return self.lm_head diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py index f01a1735f3..4271a37ee5 100644 --- a/src/transformers/models/plbart/modeling_plbart.py +++ b/src/transformers/models/plbart/modeling_plbart.py @@ -1267,9 +1267,9 @@ class PLBartForConditionalGeneration(PLBartPreTrainedModel): def get_decoder(self): return self.model.get_decoder() - def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding: - new_embeddings = super().resize_token_embeddings(new_num_tokens) - self._resize_final_logits_bias(new_num_tokens) + def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding: + new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of) + self._resize_final_logits_bias(new_embeddings.weight.shape[0]) return new_embeddings def _resize_final_logits_bias(self, new_num_tokens: int) -> None: diff --git a/src/transformers/models/speech_to_text/modeling_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_speech_to_text.py index 4413837819..60889972a5 100755 --- a/src/transformers/models/speech_to_text/modeling_speech_to_text.py +++ b/src/transformers/models/speech_to_text/modeling_speech_to_text.py @@ -1282,10 +1282,6 @@ class Speech2TextForConditionalGeneration(Speech2TextPreTrainedModel): def get_decoder(self): return self.model.get_decoder() - def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding: - new_embeddings = super().resize_token_embeddings(new_num_tokens) - return new_embeddings - def get_output_embeddings(self): return self.lm_head diff --git a/src/transformers/models/speecht5/modeling_speecht5.py b/src/transformers/models/speecht5/modeling_speecht5.py index 8471cb76ae..df31075192 100644 --- a/src/transformers/models/speecht5/modeling_speecht5.py +++ b/src/transformers/models/speecht5/modeling_speecht5.py @@ -2359,10 +2359,6 @@ class SpeechT5ForSpeechToText(SpeechT5PreTrainedModel): """ self.get_encoder().prenet.freeze_feature_encoder() - def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding: - new_embeddings = super().resize_token_embeddings(new_num_tokens) - return new_embeddings - def get_output_embeddings(self): return self.text_decoder_postnet.get_output_embeddings() diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 7aac3e1c18..926101156d 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -1410,10 +1410,6 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel): def get_decoder(self): return self.model.get_decoder() - def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding: - new_embeddings = super().resize_token_embeddings(new_num_tokens) - return new_embeddings - def get_output_embeddings(self): return self.proj_out diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index e4a3f2de60..eed704d3bc 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -1413,6 +1413,26 @@ class ModelTesterMixin: self.assertTrue(models_equal) + config = copy.deepcopy(original_config) + model = model_class(config) + model.to(torch_device) + + model_vocab_size = config.vocab_size + model.resize_token_embeddings(model_vocab_size + 10, pad_to_multiple_of=1) + self.assertTrue(model.config.vocab_size + 10, model_vocab_size) + + model_embed = model.resize_token_embeddings(model_vocab_size, pad_to_multiple_of=64) + self.assertTrue(model_embed.weight.shape[0] // 64, 0) + + model_embed = model.resize_token_embeddings(model_vocab_size + 13, pad_to_multiple_of=64) + self.assertTrue(model_embed.weight.shape[0] // 64, 0) + + with self.assertRaisesRegex( + ValueError, + "Asking to pad the embedding matrix to a multiple of `1.3`, which is not and integer. Please make sure to pass an integer", + ): + model.resize_token_embeddings(model_vocab_size, pad_to_multiple_of=1.3) + def test_resize_embeddings_untied(self): ( original_config,