From 2a737bffef0b1afd03750ba2ef0a342fff190446 Mon Sep 17 00:00:00 2001 From: Suraj Patil Date: Mon, 8 Mar 2021 16:06:19 +0530 Subject: [PATCH] [M2M100] fix positional embeddings (#10590) * fix tests * emb should be a parameter * fix positional embeddings * fix make_weights * don't save pos embeds * add comment to describe the clamping --- .../models/m2m_100/modeling_m2m_100.py | 24 +++++++++++++++---- tests/test_modeling_m2m_100.py | 14 +++++++---- 2 files changed, 29 insertions(+), 9 deletions(-) diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py index bb9f56a443..81fb4bd609 100755 --- a/src/transformers/models/m2m_100/modeling_m2m_100.py +++ b/src/transformers/models/m2m_100/modeling_m2m_100.py @@ -121,8 +121,17 @@ class M2M100SinusoidalPositionalEmbedding(nn.Module): self.offset = 2 self.embedding_dim = embedding_dim self.padding_idx = padding_idx - self.weights = self.get_embedding(num_positions + self.offset, embedding_dim, padding_idx) - self.register_buffer("_float_tensor", torch.FloatTensor(1)) + self.make_weights(num_positions + self.offset, embedding_dim, padding_idx) + + def make_weights(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None): + emb_weights = self.get_embedding(num_embeddings, embedding_dim, padding_idx) + if hasattr(self, "weights"): + # in forward, put the weights on correct device + emb_weights = emb_weights.to(self.weights.device) + + self.weights = nn.Parameter(emb_weights) + self.weights.requires_grad = False + self.weights.detach_() @staticmethod def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None): @@ -142,6 +151,7 @@ class M2M100SinusoidalPositionalEmbedding(nn.Module): emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1) if padding_idx is not None: emb[padding_idx, :] = 0 + return emb @torch.no_grad() @@ -161,9 +171,7 @@ class M2M100SinusoidalPositionalEmbedding(nn.Module): # expand 
embeddings if needed max_pos = self.padding_idx + 1 + seq_len if max_pos > self.weights.size(0): - self.weights = self.get_embedding(max_pos, self.embedding_dim, self.padding_idx) - - self.weights = self.weights.to(self._float_tensor) + self.make_weights(max_pos + self.offset, self.embedding_dim, self.padding_idx) return self.weights.index_select(0, position_ids.view(-1)).view(bsz, seq_len, -1).detach() @@ -1149,6 +1157,12 @@ class M2M100ForConditionalGeneration(M2M100PreTrainedModel): r"encoder\.version", r"decoder\.version", r"lm_head\.weight", + r"model.encoder.embed_positions.weights", + r"model.decoder.embed_positions.weights", + ] + _keys_to_ignore_on_save = [ + r"model.encoder.embed_positions.weights", + r"model.decoder.embed_positions.weights", ] def __init__(self, config: M2M100Config): diff --git a/tests/test_modeling_m2m_100.py b/tests/test_modeling_m2m_100.py index 0e02ebdc18..688403efaf 100644 --- a/tests/test_modeling_m2m_100.py +++ b/tests/test_modeling_m2m_100.py @@ -96,13 +96,19 @@ class M2M100ModelTester: def prepare_config_and_inputs(self): input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) - input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size).clamp( - 3, - ) input_ids[:, -1] = self.eos_token_id # Eos Token - decoder_input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + # we need to clamp the input ids here to avoid having pad token in between + # this is because for M2M100 the position_ids are prepared such that + # all pad tokens have pos id = 2 and rest are between 2..seq_length + # and the seq_length here is seq_length - num_pad_tokens + # but when using past, there is no way of knowing if the past input ids had + # pad tokens in them, which results in incorrect seq_length and which in turn results in + # position_ids being off by num_pad_tokens in past input + input_ids = input_ids.clamp(self.pad_token_id + 1) + decoder_input_ids = 
decoder_input_ids.clamp(self.pad_token_id + 1) + config = M2M100Config( vocab_size=self.vocab_size, d_model=self.hidden_size,