From 44f64132a5f50726f9de4467ed745421c3b11ab3 Mon Sep 17 00:00:00 2001 From: Suraj Patil Date: Wed, 10 Mar 2021 09:52:31 +0530 Subject: [PATCH] remove final_logits_bias (#10606) --- .../models/m2m_100/modeling_m2m_100.py | 14 +------------- 1 file changed, 1 insertion(+), 13 deletions(-) diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py index 81fb4bd609..4505c9fc1a 100755 --- a/src/transformers/models/m2m_100/modeling_m2m_100.py +++ b/src/transformers/models/m2m_100/modeling_m2m_100.py @@ -1153,7 +1153,6 @@ class M2M100Model(M2M100PreTrainedModel): class M2M100ForConditionalGeneration(M2M100PreTrainedModel): base_model_prefix = "model" _keys_to_ignore_on_load_missing = [ - r"final_logits_bias", r"encoder\.version", r"decoder\.version", r"lm_head\.weight", @@ -1168,7 +1167,6 @@ class M2M100ForConditionalGeneration(M2M100PreTrainedModel): def __init__(self, config: M2M100Config): super().__init__(config) self.model = M2M100Model(config) - self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings))) self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False) self.init_weights() @@ -1181,18 +1179,8 @@ class M2M100ForConditionalGeneration(M2M100PreTrainedModel): def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding: new_embeddings = super().resize_token_embeddings(new_num_tokens) - self._resize_final_logits_bias(new_num_tokens) return new_embeddings - def _resize_final_logits_bias(self, new_num_tokens: int) -> None: - old_num_tokens = self.final_logits_bias.shape[-1] - if new_num_tokens <= old_num_tokens: - new_bias = self.final_logits_bias[:, :new_num_tokens] - else: - extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device) - new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1) - self.register_buffer("final_logits_bias", new_bias) - def get_output_embeddings(self): return self.lm_head @@ -1266,7 +1254,7 @@ class M2M100ForConditionalGeneration(M2M100PreTrainedModel): output_hidden_states=output_hidden_states, return_dict=return_dict, ) - lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias + lm_logits = self.lm_head(outputs[0]) masked_lm_loss = None if labels is not None: