[M2M100] fix positional embeddings (#10590)
* fix tests
* emb should be a parameter
* fix positional embeddings
* fix make_weights
* don't save pos embeds
* add comment to describe the clamping
This commit is contained in:
@@ -121,8 +121,17 @@ class M2M100SinusoidalPositionalEmbedding(nn.Module):
|
|||||||
self.offset = 2
|
self.offset = 2
|
||||||
self.embedding_dim = embedding_dim
|
self.embedding_dim = embedding_dim
|
||||||
self.padding_idx = padding_idx
|
self.padding_idx = padding_idx
|
||||||
self.weights = self.get_embedding(num_positions + self.offset, embedding_dim, padding_idx)
|
self.make_weights(num_positions + self.offset, embedding_dim, padding_idx)
|
||||||
self.register_buffer("_float_tensor", torch.FloatTensor(1))
|
|
||||||
|
def make_weights(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None):
|
||||||
|
emb_weights = self.get_embedding(num_embeddings, embedding_dim, padding_idx)
|
||||||
|
if hasattr(self, "weights"):
|
||||||
|
# in forward, put the weights on the correct device
|
||||||
|
emb_weights = emb_weights.to(self.weights.device)
|
||||||
|
|
||||||
|
self.weights = nn.Parameter(emb_weights)
|
||||||
|
self.weights.requires_grad = False
|
||||||
|
self.weights.detach_()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None):
|
def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None):
|
||||||
@@ -142,6 +151,7 @@ class M2M100SinusoidalPositionalEmbedding(nn.Module):
|
|||||||
emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
|
emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
|
||||||
if padding_idx is not None:
|
if padding_idx is not None:
|
||||||
emb[padding_idx, :] = 0
|
emb[padding_idx, :] = 0
|
||||||
|
|
||||||
return emb
|
return emb
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
@@ -161,9 +171,7 @@ class M2M100SinusoidalPositionalEmbedding(nn.Module):
|
|||||||
# expand embeddings if needed
|
# expand embeddings if needed
|
||||||
max_pos = self.padding_idx + 1 + seq_len
|
max_pos = self.padding_idx + 1 + seq_len
|
||||||
if max_pos > self.weights.size(0):
|
if max_pos > self.weights.size(0):
|
||||||
self.weights = self.get_embedding(max_pos, self.embedding_dim, self.padding_idx)
|
self.make_weights(max_pos + self.offset, self.embedding_dim, self.padding_idx)
|
||||||
|
|
||||||
self.weights = self.weights.to(self._float_tensor)
|
|
||||||
|
|
||||||
return self.weights.index_select(0, position_ids.view(-1)).view(bsz, seq_len, -1).detach()
|
return self.weights.index_select(0, position_ids.view(-1)).view(bsz, seq_len, -1).detach()
|
||||||
|
|
||||||
@@ -1149,6 +1157,12 @@ class M2M100ForConditionalGeneration(M2M100PreTrainedModel):
|
|||||||
r"encoder\.version",
|
r"encoder\.version",
|
||||||
r"decoder\.version",
|
r"decoder\.version",
|
||||||
r"lm_head\.weight",
|
r"lm_head\.weight",
|
||||||
|
r"model.encoder.embed_positions.weights",
|
||||||
|
r"model.decoder.embed_positions.weights",
|
||||||
|
]
|
||||||
|
_keys_to_ignore_on_save = [
|
||||||
|
r"model.encoder.embed_positions.weights",
|
||||||
|
r"model.decoder.embed_positions.weights",
|
||||||
]
|
]
|
||||||
|
|
||||||
def __init__(self, config: M2M100Config):
|
def __init__(self, config: M2M100Config):
|
||||||
|
|||||||
@@ -96,13 +96,19 @@ class M2M100ModelTester:
|
|||||||
|
|
||||||
def prepare_config_and_inputs(self):
|
def prepare_config_and_inputs(self):
|
||||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size).clamp(
|
|
||||||
3,
|
|
||||||
)
|
|
||||||
input_ids[:, -1] = self.eos_token_id # Eos Token
|
input_ids[:, -1] = self.eos_token_id # Eos Token
|
||||||
|
|
||||||
decoder_input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
decoder_input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||||
|
|
||||||
|
# we need to clamp the input ids here to avoid having pad token in between
|
||||||
|
# this is because for M2M100 the position_ids are prepared such that
|
||||||
|
# all pad tokens have pos id = 2 and rest are between 2..seq_length
|
||||||
|
# and the seq_length here is seq_length - num_pad_tokens
|
||||||
|
# but when using past, there is no way of knowing if the past input ids had
|
||||||
|
# pad tokens in them, which results in an incorrect seq_length, which in turn results in
|
||||||
|
# position_ids being off by num_pad_tokens in past input
|
||||||
|
input_ids = input_ids.clamp(self.pad_token_id + 1)
|
||||||
|
decoder_input_ids = decoder_input_ids.clamp(self.pad_token_id + 1)
|
||||||
|
|
||||||
config = M2M100Config(
|
config = M2M100Config(
|
||||||
vocab_size=self.vocab_size,
|
vocab_size=self.vocab_size,
|
||||||
d_model=self.hidden_size,
|
d_model=self.hidden_size,
|
||||||
|
|||||||
Reference in New Issue
Block a user