From 38a4bf79ad4e390d0b6375eed59565177ae61f26 Mon Sep 17 00:00:00 2001 From: Raushan Turganbay Date: Wed, 1 May 2024 12:33:00 +0500 Subject: [PATCH] Encoder-decoder models: move embedding scale to nn.Module (#30410) * move scaling to nn.Module * let the test be here for now (need to fix) * failing tests * last failing models * Revert commit 4c14817f38 * clean-up * oops forgot * codestyle * raise NotImplemented when possible * Update tests/test_modeling_common.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * skip tests in respective modeling files --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --- src/transformers/models/bart/modeling_bart.py | 29 ++++++++--- .../modeling_bigbird_pegasus.py | 35 ++++++++++--- .../models/biogpt/modeling_biogpt.py | 22 ++++++-- .../models/blenderbot/modeling_blenderbot.py | 30 ++++++++--- .../bridgetower/modeling_bridgetower.py | 5 ++ .../models/funnel/modeling_funnel.py | 6 +-- .../modeling_gptsan_japanese.py | 4 ++ .../models/m2m_100/modeling_m2m_100.py | 33 +++++++++--- .../models/mbart/modeling_mbart.py | 30 ++++++++--- .../models/nllb_moe/modeling_nllb_moe.py | 33 +++++++++--- .../models/pegasus_x/modeling_pegasus_x.py | 38 +++++++++++--- .../models/plbart/modeling_plbart.py | 33 +++++++++--- .../seamless_m4t/modeling_seamless_m4t.py | 34 ++++++++++--- .../modeling_seamless_m4t_v2.py | 34 ++++++++++--- .../models/trocr/modeling_trocr.py | 22 ++++++-- src/transformers/models/xglm/modeling_xglm.py | 22 ++++++-- tests/models/align/test_modeling_align.py | 12 +++++ tests/models/bark/test_modeling_bark.py | 50 +++++++++++++++++++ .../bridgetower/test_modeling_bridgetower.py | 4 ++ tests/models/canine/test_modeling_canine.py | 4 ++ .../test_modeling_conditional_detr.py | 4 ++ .../test_modeling_deformable_detr.py | 4 ++ tests/models/deta/test_modeling_deta.py | 4 ++ tests/models/detr/test_modeling_detr.py | 4 ++ tests/models/fsmt/test_modeling_fsmt.py | 4 ++ .../test_modeling_gptsan_japanese.py | 16 ++++++ tests/models/ibert/test_modeling_ibert.py | 4 ++ .../models/idefics2/test_modeling_idefics2.py | 4 ++ .../models/imagegpt/test_modeling_imagegpt.py | 25 ++++++++++ .../models/musicgen/test_modeling_musicgen.py | 4 ++ .../test_modeling_musicgen_melody.py | 4 ++ .../test_modeling_seamless_m4t.py | 4 ++ .../test_modeling_seamless_m4t_v2.py | 4 ++ .../test_modeling_table_transformer.py | 4 ++ tests/models/vilt/test_modeling_vilt.py | 7 +++ tests/test_modeling_common.py | 45 +++++++++++++++++ 36 files changed, 541 insertions(+), 80 deletions(-) diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index 630688d1fd..f44286bb08 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -132,6 +132,19 @@ class BartLearnedPositionalEmbedding(nn.Embedding): return super().forward(positions + self.offset) +class BartScaledWordEmbedding(nn.Embedding): + """ + This module overrides nn.Embeddings' forward by multiplying with embeddings scale. + """ + + def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0): + super().__init__(num_embeddings, embedding_dim, padding_idx) + self.embed_scale = embed_scale + + def forward(self, input_ids: torch.Tensor): + return super().forward(input_ids) * self.embed_scale + + class BartAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -1056,9 +1069,11 @@ class BartEncoder(BartPreTrainedModel): embed_dim = config.d_model self.padding_idx = config.pad_token_id self.max_source_positions = config.max_position_embeddings - self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 + embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 - self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx) + self.embed_tokens = BartScaledWordEmbedding( + config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale + ) if embed_tokens is not None: self.embed_tokens.weight = embed_tokens.weight @@ -1146,7 +1161,7 @@ class BartEncoder(BartPreTrainedModel): raise ValueError("You have to specify either input_ids or inputs_embeds") if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale + inputs_embeds = self.embed_tokens(input_ids) embed_pos = self.embed_positions(input) embed_pos = embed_pos.to(inputs_embeds.device) @@ -1238,9 +1253,11 @@ class BartDecoder(BartPreTrainedModel): self.layerdrop = config.decoder_layerdrop self.padding_idx = config.pad_token_id self.max_target_positions = config.max_position_embeddings - self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 + embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 - self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx) + self.embed_tokens = BartScaledWordEmbedding( + config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale + ) if embed_tokens is not None: self.embed_tokens.weight = embed_tokens.weight @@ -1369,7 +1386,7 @@ class BartDecoder(BartPreTrainedModel): past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input) * self.embed_scale + inputs_embeds = self.embed_tokens(input) if self._use_flash_attention_2: # 2d mask is passed through the layers diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py index b863beb75e..6ea7a822d7 100755 --- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py @@ -90,6 +90,20 @@ class BigBirdPegasusLearnedPositionalEmbedding(nn.Embedding): return super().forward(positions) +# Copied from transformers.models.bart.modeling_bart.BartScaledWordEmbedding with Bart->BigBirdPegasus +class BigBirdPegasusScaledWordEmbedding(nn.Embedding): + """ + This module overrides nn.Embeddings' forward by multiplying with embeddings scale. + """ + + def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0): + super().__init__(num_embeddings, embedding_dim, padding_idx) + self.embed_scale = embed_scale + + def forward(self, input_ids: torch.Tensor): + return super().forward(input_ids) * self.embed_scale + + # Copied from transformers.models.big_bird.modeling_big_bird.BigBirdSelfAttention with BigBird->BigBirdPegasus class BigBirdPegasusSelfAttention(nn.Module): def __init__(self, config): @@ -1749,9 +1763,11 @@ class BigBirdPegasusEncoder(BigBirdPegasusPreTrainedModel): embed_dim = config.d_model self.padding_idx = config.pad_token_id self.max_source_positions = config.max_position_embeddings - self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 + embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 - self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx) + self.embed_tokens = BigBirdPegasusScaledWordEmbedding( + config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale + ) if embed_tokens is not None: self.embed_tokens.weight = embed_tokens.weight @@ -1827,7 +1843,7 @@ class BigBirdPegasusEncoder(BigBirdPegasusPreTrainedModel): raise ValueError("You have to specify either input_ids or inputs_embeds") if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale + inputs_embeds = self.embed_tokens(input_ids) embed_pos = self.embed_positions(input_shape) @@ -2042,9 +2058,11 @@ class BigBirdPegasusDecoder(BigBirdPegasusPreTrainedModel): self.layerdrop = config.decoder_layerdrop self.padding_idx = config.pad_token_id self.max_target_positions = config.max_position_embeddings - self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 + embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 - self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx) + self.embed_tokens = BigBirdPegasusScaledWordEmbedding( + config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale + ) if embed_tokens is not None: self.embed_tokens.weight = embed_tokens.weight @@ -2168,7 +2186,7 @@ class BigBirdPegasusDecoder(BigBirdPegasusPreTrainedModel): past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale + inputs_embeds = self.embed_tokens(input_ids) attention_mask = _prepare_4d_causal_attention_mask( attention_mask, input_shape, inputs_embeds, past_key_values_length @@ -2292,7 +2310,10 @@ class BigBirdPegasusModel(BigBirdPegasusPreTrainedModel): super().__init__(config) padding_idx, vocab_size = config.pad_token_id, config.vocab_size - self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx) + embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 + self.shared = BigBirdPegasusScaledWordEmbedding( + vocab_size, config.d_model, padding_idx, embed_scale=embed_scale + ) self.encoder = BigBirdPegasusEncoder(config, self.shared) self.decoder = BigBirdPegasusDecoder(config, self.shared) diff --git a/src/transformers/models/biogpt/modeling_biogpt.py b/src/transformers/models/biogpt/modeling_biogpt.py index 30df3e0847..8a94105081 100755 --- a/src/transformers/models/biogpt/modeling_biogpt.py +++ b/src/transformers/models/biogpt/modeling_biogpt.py @@ -75,6 +75,20 @@ class BioGptLearnedPositionalEmbedding(nn.Embedding): return super().forward(positions + self.offset) +# Copied from transformers.models.bart.modeling_bart.BartScaledWordEmbedding with Bart->BioGpt +class BioGptScaledWordEmbedding(nn.Embedding): + """ + This module overrides nn.Embeddings' forward by multiplying with embeddings scale. + """ + + def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0): + super().__init__(num_embeddings, embedding_dim, padding_idx) + self.embed_scale = embed_scale + + def forward(self, input_ids: torch.Tensor): + return super().forward(input_ids) * self.embed_scale + + # Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->BioGpt class BioGptAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -423,9 +437,11 @@ class BioGptModel(BioGptPreTrainedModel): self.dropout = config.hidden_dropout_prob self.embed_dim = config.hidden_size self.padding_idx = config.pad_token_id - self.embed_scale = math.sqrt(config.hidden_size) if config.scale_embedding else 1.0 + embed_scale = math.sqrt(config.hidden_size) if config.scale_embedding else 1.0 - self.embed_tokens = nn.Embedding(config.vocab_size, self.embed_dim, self.padding_idx) + self.embed_tokens = BioGptScaledWordEmbedding( + config.vocab_size, self.embed_dim, self.padding_idx, embed_scale=embed_scale + ) self.embed_positions = BioGptLearnedPositionalEmbedding(config.max_position_embeddings, self.embed_dim) self.layers = nn.ModuleList([BioGptDecoderLayer(config) for _ in range(config.num_hidden_layers)]) @@ -482,7 +498,7 @@ class BioGptModel(BioGptPreTrainedModel): past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input) * self.embed_scale + inputs_embeds = self.embed_tokens(input) if attention_mask is None: attention_mask = torch.ones( diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py index 5fa17abcdd..1fd545691d 100755 --- a/src/transformers/models/blenderbot/modeling_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_blenderbot.py @@ -90,6 +90,20 @@ class BlenderbotLearnedPositionalEmbedding(nn.Embedding): return super().forward(positions) +# Copied from transformers.models.bart.modeling_bart.BartScaledWordEmbedding with Bart->Blenderbot +class BlenderbotScaledWordEmbedding(nn.Embedding): + """ + This module overrides nn.Embeddings' forward by multiplying with embeddings scale. + """ + + def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0): + super().__init__(num_embeddings, embedding_dim, padding_idx) + self.embed_scale = embed_scale + + def forward(self, input_ids: torch.Tensor): + return super().forward(input_ids) * self.embed_scale + + # Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->Blenderbot class BlenderbotAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -632,12 +646,14 @@ class BlenderbotEncoder(BlenderbotPreTrainedModel): embed_dim = config.d_model self.padding_idx = config.pad_token_id self.max_source_positions = config.max_position_embeddings - self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 + embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 if embed_tokens is not None: self.embed_tokens = embed_tokens else: - self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx) + self.embed_tokens = BlenderbotScaledWordEmbedding( + config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale + ) self.embed_positions = BlenderbotLearnedPositionalEmbedding( config.max_position_embeddings, @@ -715,7 +731,7 @@ class BlenderbotEncoder(BlenderbotPreTrainedModel): raise ValueError("You have to specify either input_ids or inputs_embeds") if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale + inputs_embeds = self.embed_tokens(input_ids) embed_pos = self.embed_positions(input_shape) @@ -799,12 +815,14 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel): self.layerdrop = config.decoder_layerdrop self.padding_idx = config.pad_token_id self.max_target_positions = config.max_position_embeddings - self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 + embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 if embed_tokens is not None: self.embed_tokens = embed_tokens else: - self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx) + self.embed_tokens = BlenderbotScaledWordEmbedding( + config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale + ) self.embed_positions = BlenderbotLearnedPositionalEmbedding( config.max_position_embeddings, @@ -926,7 +944,7 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel): past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale + inputs_embeds = self.embed_tokens(input_ids) attention_mask = _prepare_4d_causal_attention_mask( attention_mask, input_shape, inputs_embeds, past_key_values_length diff --git a/src/transformers/models/bridgetower/modeling_bridgetower.py b/src/transformers/models/bridgetower/modeling_bridgetower.py index 3fc9f755aa..6bbb043546 100644 --- a/src/transformers/models/bridgetower/modeling_bridgetower.py +++ b/src/transformers/models/bridgetower/modeling_bridgetower.py @@ -1325,6 +1325,11 @@ class BridgeTowerModel(BridgeTowerPreTrainedModel): all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None + if inputs_embeds is not None and input_ids is None: + raise NotImplementedError( + "BridgeTowerModel does not use `inputs_embeds`. Make sure to pass in `input_ids` instead." + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict image_token_type_idx = image_token_type_idx if image_token_type_idx else 1 input_shape = input_ids.size() diff --git a/src/transformers/models/funnel/modeling_funnel.py b/src/transformers/models/funnel/modeling_funnel.py index ce0c778948..50e98e4c04 100644 --- a/src/transformers/models/funnel/modeling_funnel.py +++ b/src/transformers/models/funnel/modeling_funnel.py @@ -972,8 +972,7 @@ class FunnelBaseModel(FunnelPreTrainedModel): token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) # TODO: deal with head_mask - if inputs_embeds is None: - inputs_embeds = self.embeddings(input_ids) + inputs_embeds = self.embeddings(input_ids, inputs_embeds=inputs_embeds) encoder_outputs = self.encoder( inputs_embeds, @@ -1048,8 +1047,7 @@ class FunnelModel(FunnelPreTrainedModel): token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) # TODO: deal with head_mask - if inputs_embeds is None: - inputs_embeds = self.embeddings(input_ids) + inputs_embeds = self.embeddings(input_ids, inputs_embeds=inputs_embeds) encoder_outputs = self.encoder( inputs_embeds, diff --git a/src/transformers/models/gptsan_japanese/modeling_gptsan_japanese.py b/src/transformers/models/gptsan_japanese/modeling_gptsan_japanese.py index 59252bc567..2582d0468d 100644 --- a/src/transformers/models/gptsan_japanese/modeling_gptsan_japanese.py +++ b/src/transformers/models/gptsan_japanese/modeling_gptsan_japanese.py @@ -920,6 +920,10 @@ class GPTSanJapaneseModel(GPTSanJapanesePreTrainedModel): device = self.position_embeddings.weight.device if input_ids is None: input_ids = torch.zeros([1, 1]).int().to(device) # dummy for input_ids was None + if inputs_embeds is not None: + raise NotImplementedError( + "GPTSanJapaneseModel does not use `inputs_embeds`. Make sure to pass in `input_ids` instead." + ) num_pasts_contexts = 0 num_batch = input_ids.shape[0] pasts_or_spout_value = None diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py index 1517610b06..1080d28c94 100755 --- a/src/transformers/models/m2m_100/modeling_m2m_100.py +++ b/src/transformers/models/m2m_100/modeling_m2m_100.py @@ -87,6 +87,20 @@ def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_l return incremental_indices.long() + padding_idx +# Copied from transformers.models.bart.modeling_bart.BartScaledWordEmbedding with Bart->M2M100 +class M2M100ScaledWordEmbedding(nn.Embedding): + """ + This module overrides nn.Embeddings' forward by multiplying with embeddings scale. + """ + + def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0): + super().__init__(num_embeddings, embedding_dim, padding_idx) + self.embed_scale = embed_scale + + def forward(self, input_ids: torch.Tensor): + return super().forward(input_ids) * self.embed_scale + + class M2M100SinusoidalPositionalEmbedding(nn.Module): """This module produces sinusoidal positional embeddings of any length.""" @@ -886,9 +900,11 @@ class M2M100Encoder(M2M100PreTrainedModel): embed_dim = config.d_model self.padding_idx = config.pad_token_id self.max_source_positions = config.max_position_embeddings - self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 + embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 - self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx) + self.embed_tokens = M2M100ScaledWordEmbedding( + config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale + ) if embed_tokens is not None: self.embed_tokens.weight = embed_tokens.weight @@ -971,7 +987,7 @@ class M2M100Encoder(M2M100PreTrainedModel): raise ValueError("You have to specify either input_ids or inputs_embeds") if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale + inputs_embeds = self.embed_tokens(input_ids) embed_pos = self.embed_positions(input_ids, inputs_embeds) embed_pos = embed_pos.to(inputs_embeds.device) @@ -1061,9 +1077,11 @@ class M2M100Decoder(M2M100PreTrainedModel): self.layerdrop = config.decoder_layerdrop self.padding_idx = config.pad_token_id self.max_target_positions = config.max_position_embeddings - self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 + embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 - self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx) + self.embed_tokens = M2M100ScaledWordEmbedding( + config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale + ) if embed_tokens is not None: self.embed_tokens.weight = embed_tokens.weight @@ -1183,7 +1201,7 @@ class M2M100Decoder(M2M100PreTrainedModel): past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale + inputs_embeds = self.embed_tokens(input_ids) if self._use_flash_attention_2: # 2d mask is passed through the layers @@ -1321,7 +1339,8 @@ class M2M100Model(M2M100PreTrainedModel): super().__init__(config) padding_idx, vocab_size = config.pad_token_id, config.vocab_size - self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx) + embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 + self.shared = M2M100ScaledWordEmbedding(vocab_size, config.d_model, padding_idx, embed_scale=embed_scale) self.encoder = M2M100Encoder(config, self.shared) self.decoder = M2M100Decoder(config, self.shared) diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index fc23e2c675..34eab5b8f3 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -118,6 +118,20 @@ class MBartLearnedPositionalEmbedding(nn.Embedding): return super().forward(positions + self.offset) +# Copied from transformers.models.bart.modeling_bart.BartScaledWordEmbedding with Bart->MBart +class MBartScaledWordEmbedding(nn.Embedding): + """ + This module overrides nn.Embeddings' forward by multiplying with embeddings scale. + """ + + def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0): + super().__init__(num_embeddings, embedding_dim, padding_idx) + self.embed_scale = embed_scale + + def forward(self, input_ids: torch.Tensor): + return super().forward(input_ids) * self.embed_scale + + # Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->MBart class MBartAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -919,9 +933,11 @@ class MBartEncoder(MBartPreTrainedModel): embed_dim = config.d_model self.padding_idx = config.pad_token_id self.max_source_positions = config.max_position_embeddings - self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 + embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 - self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx) + self.embed_tokens = MBartScaledWordEmbedding( + config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale + ) if embed_tokens is not None: self.embed_tokens.weight = embed_tokens.weight @@ -1009,7 +1025,7 @@ class MBartEncoder(MBartPreTrainedModel): raise ValueError("You have to specify either input_ids or inputs_embeds") if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale + inputs_embeds = self.embed_tokens(input_ids) embed_pos = self.embed_positions(input) @@ -1097,9 +1113,11 @@ class MBartDecoder(MBartPreTrainedModel): self.layerdrop = config.decoder_layerdrop self.padding_idx = config.pad_token_id self.max_target_positions = config.max_position_embeddings - self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 + embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 - self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx) + self.embed_tokens = MBartScaledWordEmbedding( + config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale + ) if embed_tokens is not None: self.embed_tokens.weight = embed_tokens.weight @@ -1227,7 +1245,7 @@ class MBartDecoder(MBartPreTrainedModel): past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale + inputs_embeds = self.embed_tokens(input_ids) if self._use_flash_attention_2: # 2d mask is passed through the layers diff --git a/src/transformers/models/nllb_moe/modeling_nllb_moe.py b/src/transformers/models/nllb_moe/modeling_nllb_moe.py index 4ef66b7bd5..e8c827b608 100644 --- a/src/transformers/models/nllb_moe/modeling_nllb_moe.py +++ b/src/transformers/models/nllb_moe/modeling_nllb_moe.py @@ -133,6 +133,20 @@ def load_balancing_loss_func(router_probs: torch.Tensor, expert_indices: torch.T return torch.mean(tokens_per_group_and_expert * router_prob_per_group_and_expert) * (num_experts**2) +# Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100ScaledWordEmbedding with M2M100->NllbMoe +class NllbMoeScaledWordEmbedding(nn.Embedding): + """ + This module overrides nn.Embeddings' forward by multiplying with embeddings scale. + """ + + def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0): + super().__init__(num_embeddings, embedding_dim, padding_idx) + self.embed_scale = embed_scale + + def forward(self, input_ids: torch.Tensor): + return super().forward(input_ids) * self.embed_scale + + # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100SinusoidalPositionalEmbedding class NllbMoeSinusoidalPositionalEmbedding(nn.Module): """This module produces sinusoidal positional embeddings of any length.""" @@ -992,9 +1006,11 @@ class NllbMoeEncoder(NllbMoePreTrainedModel): embed_dim = config.d_model self.padding_idx = config.pad_token_id self.max_source_positions = config.max_position_embeddings - self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 + embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 - self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx) + self.embed_tokens = NllbMoeScaledWordEmbedding( + config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale + ) if embed_tokens is not None: self.embed_tokens.weight = embed_tokens.weight @@ -1085,7 +1101,7 @@ class NllbMoeEncoder(NllbMoePreTrainedModel): raise ValueError("You have to specify either input_ids or inputs_embeds") if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale + inputs_embeds = self.embed_tokens(input_ids) embed_pos = self.embed_positions(input_ids, inputs_embeds) embed_pos = embed_pos.to(inputs_embeds.device) @@ -1178,9 +1194,11 @@ class NllbMoeDecoder(NllbMoePreTrainedModel): self.layerdrop = config.decoder_layerdrop self.padding_idx = config.pad_token_id self.max_target_positions = config.max_position_embeddings - self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 + embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 - self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx) + self.embed_tokens = NllbMoeScaledWordEmbedding( + config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale + ) if embed_tokens is not None: self.embed_tokens.weight = embed_tokens.weight @@ -1309,7 +1327,7 @@ class NllbMoeDecoder(NllbMoePreTrainedModel): past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale + inputs_embeds = self.embed_tokens(input_ids) # create causal mask # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] @@ -1458,7 +1476,8 @@ class NllbMoeModel(NllbMoePreTrainedModel): super().__init__(config) padding_idx, vocab_size = config.pad_token_id, config.vocab_size - self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx) + embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 + self.shared = NllbMoeScaledWordEmbedding(vocab_size, config.d_model, padding_idx, embed_scale=embed_scale) self.encoder = NllbMoeEncoder(config, self.shared) self.decoder = NllbMoeDecoder(config, self.shared) diff --git a/src/transformers/models/pegasus_x/modeling_pegasus_x.py b/src/transformers/models/pegasus_x/modeling_pegasus_x.py index f31ccccbb1..2a5e9a1fc2 100755 --- a/src/transformers/models/pegasus_x/modeling_pegasus_x.py +++ b/src/transformers/models/pegasus_x/modeling_pegasus_x.py @@ -87,6 +87,20 @@ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start return shifted_input_ids +# Copied from transformers.models.bart.modeling_bart.BartScaledWordEmbedding with Bart->PegasusX +class PegasusXScaledWordEmbedding(nn.Embedding): + """ + This module overrides nn.Embeddings' forward by multiplying with embeddings scale. + """ + + def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0): + super().__init__(num_embeddings, embedding_dim, padding_idx) + self.embed_scale = embed_scale + + def forward(self, input_ids: torch.Tensor): + return super().forward(input_ids) * self.embed_scale + + class PegasusXSinusoidalPositionalEmbedding(nn.Module): """This module produces sinusoidal positional embeddings of any length.""" @@ -880,13 +894,16 @@ class PegasusXEncoder(PegasusXPreTrainedModel): self.layerdrop = config.encoder_layerdrop embed_dim = config.d_model + padding_idx = config.pad_token_id self.max_source_positions = config.max_position_embeddings - self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 + embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 if embed_tokens is not None: self.embed_tokens = embed_tokens else: - self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim) + self.embed_tokens = PegasusXScaledWordEmbedding( + config.vocab_size, embed_dim, padding_idx, embed_scale=embed_scale + ) self.embed_global = nn.Embedding(config.num_global_tokens, embed_dim) self.embed_positions = PegasusXSinusoidalPositionalEmbedding(embed_dim) @@ -988,7 +1005,7 @@ class PegasusXEncoder(PegasusXPreTrainedModel): raise ValueError("You have to specify either input_ids or inputs_embeds") if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale + inputs_embeds = self.embed_tokens(input_ids) embed_pos = self.embed_positions(inputs_embeds) @@ -1086,12 +1103,15 @@ class PegasusXDecoder(PegasusXPreTrainedModel): self.dropout = config.dropout self.layerdrop = config.decoder_layerdrop self.max_target_positions = config.max_position_embeddings - self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 + embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 + padding_idx = config.pad_token_id if embed_tokens is not None: self.embed_tokens = embed_tokens else: - self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model) + self.embed_tokens = PegasusXScaledWordEmbedding( + config.vocab_size, config.d_model, padding_idx=padding_idx, embed_scale=embed_scale + ) self.embed_positions = PegasusXSinusoidalPositionalEmbedding(config.d_model) self.layers = nn.ModuleList([PegasusXDecoderLayer(config) for _ in range(config.decoder_layers)]) @@ -1196,7 +1216,7 @@ class PegasusXDecoder(PegasusXPreTrainedModel): past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale + inputs_embeds = self.embed_tokens(input_ids) attention_mask = _prepare_4d_causal_attention_mask( attention_mask, input_shape, inputs_embeds, past_key_values_length @@ -1307,7 +1327,11 @@ class PegasusXModel(PegasusXPreTrainedModel): super().__init__(config) vocab_size = config.vocab_size - self.shared = nn.Embedding(vocab_size, config.d_model) + embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 + padding_idx = config.pad_token_id + self.shared = PegasusXScaledWordEmbedding( + vocab_size, config.d_model, padding_idx=padding_idx, embed_scale=embed_scale + ) self.encoder = PegasusXEncoder(config, self.shared) self.decoder = PegasusXDecoder(config, self.shared) diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py index d60b7ee4b0..78200e92eb 100644 --- a/src/transformers/models/plbart/modeling_plbart.py +++ b/src/transformers/models/plbart/modeling_plbart.py @@ -102,6 +102,20 @@ class PLBartLearnedPositionalEmbedding(nn.Embedding): return super().forward(positions + self.offset) +# Copied from transformers.models.bart.modeling_bart.BartScaledWordEmbedding with Bart->PLBart +class PLBartScaledWordEmbedding(nn.Embedding): + """ + This module overrides nn.Embeddings' forward by multiplying with embeddings scale. + """ + + def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0): + super().__init__(num_embeddings, embedding_dim, padding_idx) + self.embed_scale = embed_scale + + def forward(self, input_ids: torch.Tensor): + return super().forward(input_ids) * self.embed_scale + + # Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->PLBart class PLBartAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -658,9 +672,11 @@ class PLBartEncoder(PLBartPreTrainedModel): embed_dim = config.d_model self.padding_idx = config.pad_token_id self.max_source_positions = config.max_position_embeddings - self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 + embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 - self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx) + self.embed_tokens = PLBartScaledWordEmbedding( + config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale + ) if embed_tokens is not None: self.embed_tokens.weight = embed_tokens.weight @@ -748,7 +764,7 @@ class PLBartEncoder(PLBartPreTrainedModel): raise ValueError("You have to specify either input_ids or inputs_embeds") if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale + inputs_embeds = self.embed_tokens(input_ids) embed_pos = self.embed_positions(input) embed_pos = embed_pos.to(inputs_embeds.device) @@ -841,9 +857,11 @@ class PLBartDecoder(PLBartPreTrainedModel): self.layerdrop = config.decoder_layerdrop self.padding_idx = config.pad_token_id self.max_target_positions = config.max_position_embeddings - self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 + embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 - self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx) + self.embed_tokens = PLBartScaledWordEmbedding( + config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale + ) if embed_tokens is not None: self.embed_tokens.weight = embed_tokens.weight @@ -972,7 +990,7 @@ class PLBartDecoder(PLBartPreTrainedModel): past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input) * self.embed_scale + inputs_embeds = self.embed_tokens(input) if self._use_flash_attention_2: # 2d mask is passed through the layers @@ -1122,7 +1140,8 @@ class PLBartModel(PLBartPreTrainedModel): super().__init__(config) padding_idx, vocab_size = config.pad_token_id, config.vocab_size - self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx) + embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 + self.shared = PLBartScaledWordEmbedding(vocab_size, config.d_model, padding_idx, embed_scale=embed_scale) self.encoder = PLBartEncoder(config, self.shared) self.decoder = PLBartDecoder(config, self.shared) diff --git a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py index c0fe60a643..bfc0fb5aeb 100755 --- a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py +++ b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py @@ -989,6 +989,20 @@ class SeamlessM4TConformerAdapter(nn.Module): ############ TEXT / UNITS related code ################ +# Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100ScaledWordEmbedding with M2M100->SeamlessM4T +class SeamlessM4TScaledWordEmbedding(nn.Embedding): + """ + This module overrides nn.Embeddings' forward by multiplying with embeddings scale. + """ + + def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0): + super().__init__(num_embeddings, embedding_dim, padding_idx) + self.embed_scale = embed_scale + + def forward(self, input_ids: torch.Tensor): + return super().forward(input_ids) * self.embed_scale + + # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100SinusoidalPositionalEmbedding class SeamlessM4TSinusoidalPositionalEmbedding(nn.Module): """This module produces sinusoidal positional embeddings of any length.""" @@ -1631,9 +1645,11 @@ class SeamlessM4TEncoder(SeamlessM4TPreTrainedModel): self.max_source_positions = config.max_position_embeddings if not self.is_t2u_encoder: - self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 + embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 - self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx) + self.embed_tokens = SeamlessM4TScaledWordEmbedding( + config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale + ) if embed_tokens is not None: self.embed_tokens.weight = embed_tokens.weight @@ -1726,7 +1742,7 @@ class SeamlessM4TEncoder(SeamlessM4TPreTrainedModel): raise ValueError("You have to specify either input_ids or inputs_embeds") if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale + inputs_embeds = self.embed_tokens(input_ids) if not self.is_t2u_encoder: embed_pos = self.embed_positions(input) @@ -1809,14 +1825,18 @@ class SeamlessM4TDecoder(SeamlessM4TPreTrainedModel): self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size self.max_target_positions = config.max_position_embeddings - self.embed_scale = math.sqrt(config.hidden_size) if config.scale_embedding else 1.0 + embed_scale = math.sqrt(config.hidden_size) if config.scale_embedding else 1.0 if embed_tokens is not None: # if embed_tokens defined, use its shape instead - self.embed_tokens = nn.Embedding(embed_tokens.num_embeddings, embed_tokens.embedding_dim, self.padding_idx) + self.embed_tokens = SeamlessM4TScaledWordEmbedding( + embed_tokens.num_embeddings, embed_tokens.embedding_dim, self.padding_idx, embed_scale=embed_scale + ) self.embed_tokens.weight = embed_tokens.weight else: - self.embed_tokens = nn.Embedding(self.vocab_size, config.hidden_size, self.padding_idx) + self.embed_tokens = SeamlessM4TScaledWordEmbedding( + self.vocab_size, config.hidden_size, self.padding_idx, embed_scale=embed_scale + ) self.embed_positions = SeamlessM4TSinusoidalPositionalEmbedding( self.max_target_positions, @@ -1935,7 +1955,7 @@ class SeamlessM4TDecoder(SeamlessM4TPreTrainedModel): past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale + inputs_embeds = self.embed_tokens(input_ids) attention_mask = _prepare_4d_causal_attention_mask( attention_mask, input_shape, inputs_embeds, past_key_values_length diff --git a/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py b/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py index c7f90f6c0a..238e089c31 100644 --- a/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +++ b/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py @@ -946,6 +946,20 @@ class SeamlessM4Tv2ConformerAdapter(nn.Module): ############ TEXT / UNITS related code ################ +# Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100ScaledWordEmbedding with M2M100->SeamlessM4Tv2 +class SeamlessM4Tv2ScaledWordEmbedding(nn.Embedding): + """ + This module overrides nn.Embeddings' forward by multiplying with embeddings scale. + """ + + def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0): + super().__init__(num_embeddings, embedding_dim, padding_idx) + self.embed_scale = embed_scale + + def forward(self, input_ids: torch.Tensor): + return super().forward(input_ids) * self.embed_scale + + # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100SinusoidalPositionalEmbedding class SeamlessM4Tv2SinusoidalPositionalEmbedding(nn.Module): """This module produces sinusoidal positional embeddings of any length.""" @@ -1753,9 +1767,11 @@ class SeamlessM4Tv2Encoder(SeamlessM4Tv2PreTrainedModel): self.max_source_positions = config.max_position_embeddings if not self.is_t2u_encoder: - self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 + embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 - self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx) + self.embed_tokens = SeamlessM4Tv2ScaledWordEmbedding( + config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale + ) if embed_tokens is not None: self.embed_tokens.weight = embed_tokens.weight @@ -1848,7 +1864,7 @@ class SeamlessM4Tv2Encoder(SeamlessM4Tv2PreTrainedModel): raise ValueError("You have to specify either input_ids or inputs_embeds") if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale + inputs_embeds = self.embed_tokens(input_ids) if not self.is_t2u_encoder: embed_pos = self.embed_positions(input) @@ -1932,14 +1948,18 @@ class SeamlessM4Tv2Decoder(SeamlessM4Tv2PreTrainedModel): self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size self.max_target_positions = config.max_position_embeddings - self.embed_scale = math.sqrt(config.hidden_size) if config.scale_embedding else 1.0 + embed_scale = math.sqrt(config.hidden_size) if config.scale_embedding else 1.0 if embed_tokens is not None: # if embed_tokens defined, use its shape instead - self.embed_tokens = nn.Embedding(embed_tokens.num_embeddings, embed_tokens.embedding_dim, self.padding_idx) + self.embed_tokens = SeamlessM4Tv2ScaledWordEmbedding( + embed_tokens.num_embeddings, embed_tokens.embedding_dim, self.padding_idx, embed_scale=embed_scale + ) self.embed_tokens.weight = embed_tokens.weight else: - self.embed_tokens = nn.Embedding(self.vocab_size, config.hidden_size, self.padding_idx) + self.embed_tokens = SeamlessM4Tv2ScaledWordEmbedding( + self.vocab_size, config.hidden_size, self.padding_idx, embed_scale=embed_scale + ) self.embed_positions = SeamlessM4Tv2SinusoidalPositionalEmbedding( self.max_target_positions, @@ -2058,7 +2078,7 @@ class SeamlessM4Tv2Decoder(SeamlessM4Tv2PreTrainedModel): past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale + inputs_embeds = self.embed_tokens(input_ids) attention_mask = _prepare_4d_causal_attention_mask( attention_mask, input_shape, inputs_embeds, past_key_values_length diff --git a/src/transformers/models/trocr/modeling_trocr.py b/src/transformers/models/trocr/modeling_trocr.py index c80171292b..a20c56e331 100644 --- a/src/transformers/models/trocr/modeling_trocr.py +++ b/src/transformers/models/trocr/modeling_trocr.py @@ -63,6 +63,20 @@ class TrOCRLearnedPositionalEmbedding(nn.Embedding): return super().forward(positions + self.offset) +# Copied from transformers.models.bart.modeling_bart.BartScaledWordEmbedding with Bart->TrOCR +class TrOCRScaledWordEmbedding(nn.Embedding): + """ + This module overrides nn.Embeddings' forward by multiplying with embeddings scale. + """ + + def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0): + super().__init__(num_embeddings, embedding_dim, padding_idx) + self.embed_scale = embed_scale + + def forward(self, input_ids: torch.Tensor): + return super().forward(input_ids) * self.embed_scale + + class TrOCRSinusoidalPositionalEmbedding(nn.Module): """This module produces sinusoidal positional embeddings of any length.""" @@ -451,9 +465,11 @@ class TrOCRDecoder(TrOCRPreTrainedModel): self.dropout = config.dropout self.layerdrop = config.decoder_layerdrop self.padding_idx = config.pad_token_id - self.embed_scale = math.sqrt(config.hidden_size) if config.scale_embedding else 1.0 + embed_scale = math.sqrt(config.hidden_size) if config.scale_embedding else 1.0 - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.embed_tokens = TrOCRScaledWordEmbedding( + config.vocab_size, config.hidden_size, self.padding_idx, embed_scale=embed_scale + ) if config.use_learned_position_embeddings: self.embed_positions = TrOCRLearnedPositionalEmbedding(config.max_position_embeddings, config.hidden_size) @@ -584,7 +600,7 @@ class TrOCRDecoder(TrOCRPreTrainedModel): past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale + inputs_embeds = self.embed_tokens(input_ids) if self.config.use_learned_position_embeddings: embed_pos = self.embed_positions(input, past_key_values_length=past_key_values_length) diff --git a/src/transformers/models/xglm/modeling_xglm.py b/src/transformers/models/xglm/modeling_xglm.py index 7ec48b6f9d..538c852ae9 100755 --- a/src/transformers/models/xglm/modeling_xglm.py +++ b/src/transformers/models/xglm/modeling_xglm.py @@ -127,6 +127,20 @@ XGLM_INPUTS_DOCSTRING = r""" """ +# Copied from transformers.models.bart.modeling_bart.BartScaledWordEmbedding with Bart->XGLM +class XGLMScaledWordEmbedding(nn.Embedding): + """ + This module overrides nn.Embeddings' forward by multiplying with embeddings scale. + """ + + def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0): + super().__init__(num_embeddings, embedding_dim, padding_idx) + self.embed_scale = embed_scale + + def forward(self, input_ids: torch.Tensor): + return super().forward(input_ids) * self.embed_scale + + class XGLMSinusoidalPositionalEmbedding(nn.Module): """This module produces sinusoidal positional embeddings of any length.""" @@ -490,12 +504,14 @@ class XGLMModel(XGLMPreTrainedModel): self.layerdrop = config.layerdrop self.padding_idx = config.pad_token_id self.max_target_positions = config.max_position_embeddings - self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 + embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 if embed_tokens is not None: self.embed_tokens = embed_tokens else: - self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx) + self.embed_tokens = XGLMScaledWordEmbedding( + config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale + ) self.embed_positions = XGLMSinusoidalPositionalEmbedding( config.max_position_embeddings, @@ -568,7 +584,7 @@ class XGLMModel(XGLMPreTrainedModel): position_ids = position_ids.unsqueeze(0) if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale + inputs_embeds = self.embed_tokens(input_ids) attention_mask = _prepare_4d_causal_attention_mask( attention_mask, input_shape, inputs_embeds, past_key_values_length diff --git a/tests/models/align/test_modeling_align.py b/tests/models/align/test_modeling_align.py index ee50a1a74b..cf6962a3c6 100644 --- a/tests/models/align/test_modeling_align.py +++ b/tests/models/align/test_modeling_align.py @@ -167,6 +167,10 @@ class AlignVisionModelTest(ModelTesterMixin, unittest.TestCase): def test_inputs_embeds(self): pass + @unittest.skip(reason="AlignVisionModel does not use inputs_embeds") + def test_inputs_embeds_matches_input_ids(self): + pass + @unittest.skip(reason="AlignVisionModel does not support input and output embeddings") def test_model_common_attributes(self): pass @@ -379,6 +383,10 @@ class AlignTextModelTest(ModelTesterMixin, unittest.TestCase): def test_inputs_embeds(self): pass + @unittest.skip(reason="Align does not use inputs_embeds") + def test_inputs_embeds_matches_input_ids(self): + pass + @unittest.skip(reason="AlignTextModel has no base class and is not available in MODEL_MAPPING") def test_save_load_fast_init_from_base(self): pass @@ -473,6 +481,10 @@ class AlignModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): def test_inputs_embeds(self): pass + @unittest.skip(reason="Align does not use inputs_embeds") + def test_inputs_embeds_matches_input_ids(self): + pass + @unittest.skip(reason="Retain_grad is tested in individual model tests") def test_retain_grad_hidden_states_attentions(self): pass diff --git a/tests/models/bark/test_modeling_bark.py b/tests/models/bark/test_modeling_bark.py index 04a6ad99b8..476031068f 100644 --- a/tests/models/bark/test_modeling_bark.py +++ b/tests/models/bark/test_modeling_bark.py @@ -579,6 +579,29 @@ class BarkSemanticModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Te with torch.no_grad(): model(**inputs)[0] + # override as the input arg is called "input_embeds", not "inputs_embeds" + def test_inputs_embeds_matches_input_ids(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + model.to(torch_device) + model.eval() + + inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class)) + with torch.no_grad(): + out_ids = model(**inputs)[0] + + input_ids = inputs["input_ids"] + del inputs["input_ids"] + + wte = model.get_input_embeddings() + inputs["input_embeds"] = wte(input_ids) + with torch.no_grad(): + out_embeds = model(**inputs)[0] + + self.assertTrue(torch.allclose(out_embeds, out_ids)) + @require_torch_fp16 def test_generate_fp16(self): config, input_dict = self.model_tester.prepare_config_and_inputs() @@ -645,6 +668,29 @@ class BarkCoarseModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test with torch.no_grad(): model(**inputs)[0] + # override as the input arg is called "input_embeds", not "inputs_embeds" + def test_inputs_embeds_matches_input_ids(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + model.to(torch_device) + model.eval() + + inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class)) + with torch.no_grad(): + out_ids = model(**inputs)[0] + + input_ids = inputs["input_ids"] + del inputs["input_ids"] + + wte = model.get_input_embeddings() + inputs["input_embeds"] = wte(input_ids) + with torch.no_grad(): + out_embeds = model(**inputs)[0] + + self.assertTrue(torch.allclose(out_embeds, out_ids)) + @require_torch_fp16 def test_generate_fp16(self): config, input_dict = self.model_tester.prepare_config_and_inputs() @@ -709,6 +755,10 @@ class BarkFineModelTest(ModelTesterMixin, unittest.TestCase): with torch.no_grad(): model(**inputs)[0] + @unittest.skip("FineModel relies on codebook idx and does not return same logits") + def test_inputs_embeds_matches_input_ids(self): + pass + @require_torch_fp16 def test_generate_fp16(self): config, input_dict = self.model_tester.prepare_config_and_inputs() diff --git a/tests/models/bridgetower/test_modeling_bridgetower.py b/tests/models/bridgetower/test_modeling_bridgetower.py index 971ea4f08a..c9ce8f076f 100644 --- a/tests/models/bridgetower/test_modeling_bridgetower.py +++ b/tests/models/bridgetower/test_modeling_bridgetower.py @@ -506,6 +506,10 @@ class BridgeTowerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestC def test_inputs_embeds(self): pass + @unittest.skip(reason="Bridge Tower does not use inputs_embeds") + def test_inputs_embeds_matches_input_ids(self): + pass + # We will verify our results on an image of cute cats def prepare_img(): diff --git a/tests/models/canine/test_modeling_canine.py b/tests/models/canine/test_modeling_canine.py index eeb5aa40dd..5c342ee975 100644 --- a/tests/models/canine/test_modeling_canine.py +++ b/tests/models/canine/test_modeling_canine.py @@ -502,6 +502,10 @@ class CanineModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): # ViT does not use inputs_embeds pass + @unittest.skip(reason="Canine Tower does not use inputs_embeds") + def test_inputs_embeds_matches_input_ids(self): + pass + @unittest.skip("CANINE does not have a get_input_embeddings() method.") def test_model_common_attributes(self): pass diff --git a/tests/models/conditional_detr/test_modeling_conditional_detr.py b/tests/models/conditional_detr/test_modeling_conditional_detr.py index c3f77614b4..a01acffafb 100644 --- a/tests/models/conditional_detr/test_modeling_conditional_detr.py +++ b/tests/models/conditional_detr/test_modeling_conditional_detr.py @@ -247,6 +247,10 @@ class ConditionalDetrModelTest(ModelTesterMixin, GenerationTesterMixin, Pipeline def test_inputs_embeds(self): pass + @unittest.skip(reason="Conditional DETR does not use inputs_embeds") + def test_inputs_embeds_matches_input_ids(self): + pass + @unittest.skip(reason="Conditional DETR does not have a get_input_embeddings method") def test_model_common_attributes(self): pass diff --git a/tests/models/deformable_detr/test_modeling_deformable_detr.py b/tests/models/deformable_detr/test_modeling_deformable_detr.py index 36be099790..23e95267d8 100644 --- a/tests/models/deformable_detr/test_modeling_deformable_detr.py +++ b/tests/models/deformable_detr/test_modeling_deformable_detr.py @@ -253,6 +253,10 @@ class DeformableDetrModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT def test_inputs_embeds(self): pass + @unittest.skip(reason="Deformable DETR does not use inputs_embeds") + def test_inputs_embeds_matches_input_ids(self): + pass + @unittest.skip(reason="Deformable DETR does not have a get_input_embeddings method") def test_model_common_attributes(self): pass diff --git a/tests/models/deta/test_modeling_deta.py b/tests/models/deta/test_modeling_deta.py index 655bb50bb5..70c1009a50 100644 --- a/tests/models/deta/test_modeling_deta.py +++ b/tests/models/deta/test_modeling_deta.py @@ -303,6 +303,10 @@ class DetaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin def test_inputs_embeds(self): pass + @unittest.skip(reason="DETA does not use inputs_embeds") + def test_inputs_embeds_matches_input_ids(self): + pass + @unittest.skip(reason="DETA does not have a get_input_embeddings method") def test_model_common_attributes(self): pass diff --git a/tests/models/detr/test_modeling_detr.py b/tests/models/detr/test_modeling_detr.py index 27092c626d..ee1af9ed9d 100644 --- a/tests/models/detr/test_modeling_detr.py +++ b/tests/models/detr/test_modeling_detr.py @@ -247,6 +247,10 @@ class DetrModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin def test_inputs_embeds(self): pass + @unittest.skip(reason="DETR does not use inputs_embeds") + def test_inputs_embeds_matches_input_ids(self): + pass + @unittest.skip(reason="DETR does not have a get_input_embeddings method") def test_model_common_attributes(self): pass diff --git a/tests/models/fsmt/test_modeling_fsmt.py b/tests/models/fsmt/test_modeling_fsmt.py index da73b8d41d..cd5d479cd5 100644 --- a/tests/models/fsmt/test_modeling_fsmt.py +++ b/tests/models/fsmt/test_modeling_fsmt.py @@ -321,6 +321,10 @@ class FSMTModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin def test_inputs_embeds(self): pass + @unittest.skip("Input ids is required for FSMT.") + def test_inputs_embeds_matches_input_ids(self): + pass + @unittest.skip("model weights aren't tied in FSMT.") def test_tie_model_weights(self): pass diff --git a/tests/models/gptsan_japanese/test_modeling_gptsan_japanese.py b/tests/models/gptsan_japanese/test_modeling_gptsan_japanese.py index 716b8b9fc6..177323f747 100644 --- a/tests/models/gptsan_japanese/test_modeling_gptsan_japanese.py +++ b/tests/models/gptsan_japanese/test_modeling_gptsan_japanese.py @@ -182,6 +182,14 @@ class GPTSanJapaneseTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCas def test_model_parallelism(self): super().test_model_parallelism() + @unittest.skip(reason="Gptsan does not use inputs_embeds") + def test_inputs_embeds(self): + pass + + @unittest.skip(reason="Gptsan does not use inputs_embeds") + def test_inputs_embeds_matches_input_ids(self): + pass + @require_torch class GPTSanJapaneseForConditionalGenerationTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): @@ -212,6 +220,14 @@ class GPTSanJapaneseForConditionalGenerationTest(ModelTesterMixin, GenerationTes def test_model_parallelism(self): super().test_model_parallelism() + @unittest.skip(reason="Gptsan does not use inputs_embeds") + def test_inputs_embeds(self): + pass + + @unittest.skip(reason="Gptsan does not use inputs_embeds") + def test_inputs_embeds_matches_input_ids(self): + pass + @slow def test_logits(self): model = GPTSanJapaneseForConditionalGeneration.from_pretrained("Tanrei/GPTSAN-japanese") diff --git a/tests/models/ibert/test_modeling_ibert.py b/tests/models/ibert/test_modeling_ibert.py index fd3809acff..ec043a96b8 100644 --- a/tests/models/ibert/test_modeling_ibert.py +++ b/tests/models/ibert/test_modeling_ibert.py @@ -382,6 +382,10 @@ class IBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): with torch.no_grad(): model(**inputs)[0] + @unittest.skip("ibert overrides scaling to None if inputs_embeds") + def test_inputs_embeds_matches_input_ids(self): + pass + @require_torch class IBertModelIntegrationTest(unittest.TestCase): diff --git a/tests/models/idefics2/test_modeling_idefics2.py b/tests/models/idefics2/test_modeling_idefics2.py index 5553c972e6..63e6316773 100644 --- a/tests/models/idefics2/test_modeling_idefics2.py +++ b/tests/models/idefics2/test_modeling_idefics2.py @@ -180,6 +180,10 @@ class Idefics2ModelTest(ModelTesterMixin, unittest.TestCase): def test_inputs_embeds(): pass + @unittest.skip("input_embeds cannot be passed in without input_ids") + def test_inputs_embeds_matches_input_ids(self): + pass + @unittest.skip("Model does not support padding right") def test_flash_attn_2_generate_padding_right(self): pass diff --git a/tests/models/imagegpt/test_modeling_imagegpt.py b/tests/models/imagegpt/test_modeling_imagegpt.py index e18f745335..afb5ce8776 100644 --- a/tests/models/imagegpt/test_modeling_imagegpt.py +++ b/tests/models/imagegpt/test_modeling_imagegpt.py @@ -466,6 +466,31 @@ class ImageGPTModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM with torch.no_grad(): model(**inputs)[0] + # override because ImageGPT main input name is `pixel_values` + # NOTE: in latest transformers this is deprecated, `input_ids` should be used. TODO + def test_inputs_embeds_matches_input_ids(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + model.to(torch_device) + model.eval() + + inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class)) + with torch.no_grad(): + out_ids = model(**inputs)[0] + + pixel_values = inputs["pixel_values"] + del inputs["pixel_values"] + + wte = model.get_input_embeddings() + inputs["inputs_embeds"] = wte(pixel_values) + + with torch.no_grad(): + out_embeds = model(**inputs)[0] + + self.assertTrue(torch.allclose(out_embeds, out_ids)) + def _create_and_check_torchscript(self, config, inputs_dict): if not self.test_torchscript: return diff --git a/tests/models/musicgen/test_modeling_musicgen.py b/tests/models/musicgen/test_modeling_musicgen.py index 5ac9c97479..b04f99c05e 100644 --- a/tests/models/musicgen/test_modeling_musicgen.py +++ b/tests/models/musicgen/test_modeling_musicgen.py @@ -265,6 +265,10 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste lm_heads = model.get_output_embeddings() self.assertTrue(lm_heads is None or isinstance(lm_heads[0], torch.nn.Linear)) + @unittest.skip(reason="MusicGen does not use inputs_embeds") + def test_inputs_embeds_matches_input_ids(self): + pass + # skip as this model doesn't support all arguments tested def test_model_outputs_equivalence(self): pass diff --git a/tests/models/musicgen_melody/test_modeling_musicgen_melody.py b/tests/models/musicgen_melody/test_modeling_musicgen_melody.py index 98d8cc0b9f..628cc76a09 100644 --- a/tests/models/musicgen_melody/test_modeling_musicgen_melody.py +++ b/tests/models/musicgen_melody/test_modeling_musicgen_melody.py @@ -268,6 +268,10 @@ class MusicgenMelodyDecoderTest(ModelTesterMixin, GenerationTesterMixin, unittes lm_heads = model.get_output_embeddings() self.assertTrue(lm_heads is None or isinstance(lm_heads[0], torch.nn.Linear)) + @unittest.skip(reason="MusicGen melody does not use inputs_embeds") + def test_inputs_embeds_matches_input_ids(self): + pass + @unittest.skip("this model doesn't support all arguments tested") def test_model_outputs_equivalence(self): pass diff --git a/tests/models/seamless_m4t/test_modeling_seamless_m4t.py b/tests/models/seamless_m4t/test_modeling_seamless_m4t.py index c08e559057..d77aac6187 100644 --- a/tests/models/seamless_m4t/test_modeling_seamless_m4t.py +++ b/tests/models/seamless_m4t/test_modeling_seamless_m4t.py @@ -463,6 +463,10 @@ class SeamlessM4TModelWithSpeechInputTest(ModelTesterMixin, unittest.TestCase): def test_inputs_embeds(self): pass + @unittest.skip(reason="SeamlessM4TSpeechEncoder doesn't have an embedding layer") + def test_inputs_embeds_matches_input_ids(self): + pass + @unittest.skip( reason="Expected missing keys serve when using SeamlessM4TForXXX.from_pretrained from a checkpoint saved by SeamlessM4TModel.save_pretrained." ) diff --git a/tests/models/seamless_m4t_v2/test_modeling_seamless_m4t_v2.py b/tests/models/seamless_m4t_v2/test_modeling_seamless_m4t_v2.py index 699641fcfd..301a1eb44b 100644 --- a/tests/models/seamless_m4t_v2/test_modeling_seamless_m4t_v2.py +++ b/tests/models/seamless_m4t_v2/test_modeling_seamless_m4t_v2.py @@ -479,6 +479,10 @@ class SeamlessM4Tv2ModelWithSpeechInputTest(ModelTesterMixin, unittest.TestCase) def test_inputs_embeds(self): pass + @unittest.skip(reason="SeamlessM4TSpeechEncoder doesn't have an embedding layer") + def test_inputs_embeds_matches_input_ids(self): + pass + @unittest.skip( reason="Expected missing keys serve when using SeamlessM4Tv2ForXXX.from_pretrained from a checkpoint saved by SeamlessM4Tv2Model.save_pretrained." ) diff --git a/tests/models/table_transformer/test_modeling_table_transformer.py b/tests/models/table_transformer/test_modeling_table_transformer.py index d323083eb7..989517eb8c 100644 --- a/tests/models/table_transformer/test_modeling_table_transformer.py +++ b/tests/models/table_transformer/test_modeling_table_transformer.py @@ -261,6 +261,10 @@ class TableTransformerModelTest(ModelTesterMixin, GenerationTesterMixin, Pipelin def test_inputs_embeds(self): pass + @unittest.skip(reason="Table Transformer does not use inputs_embeds") + def test_inputs_embeds_matches_input_ids(self): + pass + @unittest.skip(reason="Table Transformer does not have a get_input_embeddings method") def test_model_common_attributes(self): pass diff --git a/tests/models/vilt/test_modeling_vilt.py b/tests/models/vilt/test_modeling_vilt.py index 4c877c2e18..3e25fc3bba 100644 --- a/tests/models/vilt/test_modeling_vilt.py +++ b/tests/models/vilt/test_modeling_vilt.py @@ -357,6 +357,13 @@ class ViltModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): def test_model_outputs_equivalence(self): pass + @unittest.skip( + reason="""VilT samples image tokens from a multinomial distribution, resulting in not deterministic + hidden states. Cannot test equivalence on logit level""" + ) + def test_inputs_embeds_matches_input_ids(self): + pass + def test_attention_outputs(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config.return_dict = True diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index be68ec4217..3c09acb723 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -2767,6 +2767,51 @@ class ModelTesterMixin: with torch.no_grad(): model(**inputs)[0] + def test_inputs_embeds_matches_input_ids(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + if model_class.__name__ not in get_values(MODEL_MAPPING_NAMES): + continue + model = model_class(config) + model.to(torch_device) + model.eval() + + model_forward_args = inspect.signature(model.forward).parameters + if "inputs_embeds" not in model_forward_args: + self.skipTest("This model doesn't use `inputs_embeds`") + + inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class)) + pad_token_id = config.pad_token_id if config.pad_token_id is not None else 1 + + wte = model.get_input_embeddings() + if not self.is_encoder_decoder: + input_ids = inputs["input_ids"] + # some models infer position ids/attn mask differently when input ids + # by check if pad_token let's make sure no padding is in input ids + not_pad_token_id = pad_token_id + 1 if max(0, pad_token_id - 1) == 0 else pad_token_id - 1 + input_ids[input_ids == pad_token_id] = not_pad_token_id + del inputs["input_ids"] + inputs_embeds = wte(input_ids) + with torch.no_grad(): + out_ids = model(input_ids=input_ids, **inputs)[0] + out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0] + else: + encoder_input_ids = inputs["input_ids"] + decoder_input_ids = inputs.get("decoder_input_ids", encoder_input_ids) + encoder_input_ids[encoder_input_ids == pad_token_id] = max(0, pad_token_id + 1) + decoder_input_ids[decoder_input_ids == pad_token_id] = max(0, pad_token_id + 1) + del inputs["input_ids"] + inputs.pop("decoder_input_ids", None) + inputs_embeds = wte(encoder_input_ids) + decoder_inputs_embeds = wte(decoder_input_ids) + with torch.no_grad(): + out_ids = model(input_ids=encoder_input_ids, decoder_input_ids=decoder_input_ids, **inputs)[0] + out_embeds = model( + inputs_embeds=inputs_embeds, decoder_inputs_embeds=decoder_inputs_embeds, **inputs + )[0] + self.assertTrue(torch.allclose(out_embeds, out_ids)) + @require_torch_multi_gpu def test_multi_gpu_data_parallel_forward(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()