Fix layer reference loss + previous attempted fix
This commit is contained in:
@@ -762,7 +762,7 @@ class BertForPreTraining(BertPreTrainedModel):
|
||||
if self.config.torchscript:
|
||||
self.cls.predictions.decoder.weight = nn.Parameter(input_embeddings.clone())
|
||||
else:
|
||||
self.cls.predictions.decoder = self.bert.embeddings.word_embeddings # Tied weights
|
||||
self.cls.predictions.decoder.weight = input_embeddings # Tied weights
|
||||
|
||||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None,
|
||||
next_sentence_label=None, head_mask=None):
|
||||
@@ -868,7 +868,7 @@ class BertForMaskedLM(BertPreTrainedModel):
|
||||
if self.config.torchscript:
|
||||
self.cls.predictions.decoder.weight = nn.Parameter(input_embeddings.clone())
|
||||
else:
|
||||
self.cls.predictions.decoder = self.bert.embeddings.word_embeddings # Tied weights
|
||||
self.cls.predictions.decoder.weight = input_embeddings # Tied weights
|
||||
|
||||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, head_mask=None):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user