remove final_logits_bias (#10606)
This commit is contained in:
@@ -1153,7 +1153,6 @@ class M2M100Model(M2M100PreTrainedModel):
|
|||||||
class M2M100ForConditionalGeneration(M2M100PreTrainedModel):
|
class M2M100ForConditionalGeneration(M2M100PreTrainedModel):
|
||||||
base_model_prefix = "model"
|
base_model_prefix = "model"
|
||||||
_keys_to_ignore_on_load_missing = [
|
_keys_to_ignore_on_load_missing = [
|
||||||
r"final_logits_bias",
|
|
||||||
r"encoder\.version",
|
r"encoder\.version",
|
||||||
r"decoder\.version",
|
r"decoder\.version",
|
||||||
r"lm_head\.weight",
|
r"lm_head\.weight",
|
||||||
@@ -1168,7 +1167,6 @@ class M2M100ForConditionalGeneration(M2M100PreTrainedModel):
|
|||||||
def __init__(self, config: M2M100Config):
|
def __init__(self, config: M2M100Config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self.model = M2M100Model(config)
|
self.model = M2M100Model(config)
|
||||||
self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
|
|
||||||
self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)
|
self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)
|
||||||
|
|
||||||
self.init_weights()
|
self.init_weights()
|
||||||
@@ -1181,18 +1179,8 @@ class M2M100ForConditionalGeneration(M2M100PreTrainedModel):
|
|||||||
|
|
||||||
def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
|
def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
|
||||||
new_embeddings = super().resize_token_embeddings(new_num_tokens)
|
new_embeddings = super().resize_token_embeddings(new_num_tokens)
|
||||||
self._resize_final_logits_bias(new_num_tokens)
|
|
||||||
return new_embeddings
|
return new_embeddings
|
||||||
|
|
||||||
def _resize_final_logits_bias(self, new_num_tokens: int) -> None:
|
|
||||||
old_num_tokens = self.final_logits_bias.shape[-1]
|
|
||||||
if new_num_tokens <= old_num_tokens:
|
|
||||||
new_bias = self.final_logits_bias[:, :new_num_tokens]
|
|
||||||
else:
|
|
||||||
extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device)
|
|
||||||
new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
|
|
||||||
self.register_buffer("final_logits_bias", new_bias)
|
|
||||||
|
|
||||||
def get_output_embeddings(self):
|
def get_output_embeddings(self):
|
||||||
return self.lm_head
|
return self.lm_head
|
||||||
|
|
||||||
@@ -1266,7 +1254,7 @@ class M2M100ForConditionalGeneration(M2M100PreTrainedModel):
|
|||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
)
|
)
|
||||||
lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias
|
lm_logits = self.lm_head(outputs[0])
|
||||||
|
|
||||||
masked_lm_loss = None
|
masked_lm_loss = None
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
|
|||||||
Reference in New Issue
Block a user