[M2M100] fix positional embeddings (#10590)

* fix tests

* emb should be a parameter

* fix positional embeddings

* fix make_weights

* don't save pos embeds

* add comment to describe the clamping
This commit is contained in:
Suraj Patil
2021-03-08 16:06:19 +05:30
committed by GitHub
parent d59464db6b
commit 2a737bffef
2 changed files with 29 additions and 9 deletions

View File

@@ -121,8 +121,17 @@ class M2M100SinusoidalPositionalEmbedding(nn.Module):
self.offset = 2
self.embedding_dim = embedding_dim
self.padding_idx = padding_idx
self.weights = self.get_embedding(num_positions + self.offset, embedding_dim, padding_idx)
self.register_buffer("_float_tensor", torch.FloatTensor(1))
self.make_weights(num_positions + self.offset, embedding_dim, padding_idx)
def make_weights(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None):
emb_weights = self.get_embedding(num_embeddings, embedding_dim, padding_idx)
if hasattr(self, "weights"):
# in forward, put the weights on correct device
emb_weights = emb_weights.to(self.weights.device)
self.weights = nn.Parameter(emb_weights)
self.weights.requires_grad = False
self.weights.detach_()
@staticmethod
def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None):
@@ -142,6 +151,7 @@ class M2M100SinusoidalPositionalEmbedding(nn.Module):
emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
if padding_idx is not None:
emb[padding_idx, :] = 0
return emb
@torch.no_grad()
@@ -161,9 +171,7 @@ class M2M100SinusoidalPositionalEmbedding(nn.Module):
# expand embeddings if needed
max_pos = self.padding_idx + 1 + seq_len
if max_pos > self.weights.size(0):
self.weights = self.get_embedding(max_pos, self.embedding_dim, self.padding_idx)
self.weights = self.weights.to(self._float_tensor)
self.make_weights(max_pos + self.offset, self.embedding_dim, self.padding_idx)
return self.weights.index_select(0, position_ids.view(-1)).view(bsz, seq_len, -1).detach()
@@ -1149,6 +1157,12 @@ class M2M100ForConditionalGeneration(M2M100PreTrainedModel):
r"encoder\.version",
r"decoder\.version",
r"lm_head\.weight",
r"model.encoder.embed_positions.weights",
r"model.decoder.embed_positions.weights",
]
_keys_to_ignore_on_save = [
r"model.encoder.embed_positions.weights",
r"model.decoder.embed_positions.weights",
]
def __init__(self, config: M2M100Config):