Fix layer reference loss + previous attempted fix
This commit is contained in:
@@ -762,7 +762,7 @@ class BertForPreTraining(BertPreTrainedModel):
|
|||||||
if self.config.torchscript:
|
if self.config.torchscript:
|
||||||
self.cls.predictions.decoder.weight = nn.Parameter(input_embeddings.clone())
|
self.cls.predictions.decoder.weight = nn.Parameter(input_embeddings.clone())
|
||||||
else:
|
else:
|
||||||
self.cls.predictions.decoder = self.bert.embeddings.word_embeddings # Tied weights
|
self.cls.predictions.decoder.weight = input_embeddings # Tied weights
|
||||||
|
|
||||||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None,
|
def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None,
|
||||||
next_sentence_label=None, head_mask=None):
|
next_sentence_label=None, head_mask=None):
|
||||||
@@ -868,7 +868,7 @@ class BertForMaskedLM(BertPreTrainedModel):
|
|||||||
if self.config.torchscript:
|
if self.config.torchscript:
|
||||||
self.cls.predictions.decoder.weight = nn.Parameter(input_embeddings.clone())
|
self.cls.predictions.decoder.weight = nn.Parameter(input_embeddings.clone())
|
||||||
else:
|
else:
|
||||||
self.cls.predictions.decoder = self.bert.embeddings.word_embeddings # Tied weights
|
self.cls.predictions.decoder.weight = input_embeddings # Tied weights
|
||||||
|
|
||||||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, head_mask=None):
|
def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, head_mask=None):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -566,7 +566,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
|
|||||||
if self.config.torchscript:
|
if self.config.torchscript:
|
||||||
self.lm_head.weight = nn.Parameter(input_embeddings.clone())
|
self.lm_head.weight = nn.Parameter(input_embeddings.clone())
|
||||||
else:
|
else:
|
||||||
self.lm_head = self.transformer.wte # Tied weights
|
self.lm_head.weight = input_embeddings # Tied weights
|
||||||
|
|
||||||
def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None, past=None, head_mask=None):
|
def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None, past=None, head_mask=None):
|
||||||
"""
|
"""
|
||||||
@@ -662,7 +662,7 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
|
|||||||
if self.config.torchscript:
|
if self.config.torchscript:
|
||||||
self.lm_head.weight = nn.Parameter(input_embeddings.clone())
|
self.lm_head.weight = nn.Parameter(input_embeddings.clone())
|
||||||
else:
|
else:
|
||||||
self.lm_head = self.transformer.wte # Tied weights
|
self.lm_head.weight = input_embeddings # Tied weights
|
||||||
|
|
||||||
def forward(self, input_ids, mc_token_ids=None, lm_labels=None, mc_labels=None, token_type_ids=None,
|
def forward(self, input_ids, mc_token_ids=None, lm_labels=None, mc_labels=None, token_type_ids=None,
|
||||||
position_ids=None, past=None, head_mask=None):
|
position_ids=None, past=None, head_mask=None):
|
||||||
|
|||||||
@@ -587,7 +587,7 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
|
|||||||
if self.config.torchscript:
|
if self.config.torchscript:
|
||||||
self.lm_head.weight = nn.Parameter(input_embeddings.clone())
|
self.lm_head.weight = nn.Parameter(input_embeddings.clone())
|
||||||
else:
|
else:
|
||||||
self.lm_head = self.transformer.tokens_embed # Tied weights
|
self.lm_head.weight = input_embeddings # Tied weights
|
||||||
|
|
||||||
def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None, head_mask=None):
|
def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None, head_mask=None):
|
||||||
"""
|
"""
|
||||||
@@ -700,7 +700,7 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
|
|||||||
if self.config.torchscript:
|
if self.config.torchscript:
|
||||||
self.lm_head.weight = nn.Parameter(input_embeddings.clone())
|
self.lm_head.weight = nn.Parameter(input_embeddings.clone())
|
||||||
else:
|
else:
|
||||||
self.lm_head = self.transformer.tokens_embed # Tied weights
|
self.lm_head.weight = input_embeddings # Tied weights
|
||||||
|
|
||||||
def forward(self, input_ids, mc_token_ids=None, lm_labels=None, mc_labels=None, token_type_ids=None,
|
def forward(self, input_ids, mc_token_ids=None, lm_labels=None, mc_labels=None, token_type_ids=None,
|
||||||
position_ids=None, head_mask=None):
|
position_ids=None, head_mask=None):
|
||||||
|
|||||||
@@ -541,8 +541,8 @@ class ModelUtilsTest(unittest.TestCase):
|
|||||||
model.resize_token_embeddings(config.vocab_size + 10)
|
model.resize_token_embeddings(config.vocab_size + 10)
|
||||||
decoding.weight.data.mul_(20)
|
decoding.weight.data.mul_(20)
|
||||||
# Check that the embedding layer and decoding layer are the same in size and in value
|
# Check that the embedding layer and decoding layer are the same in size and in value
|
||||||
self.assertTrue(embeddings.weight.shape, decoding.weight.shape)
|
self.assertTrue(model.transformer.wte.weight.shape, model.lm_head.weight.shape)
|
||||||
self.assertTrue(check_same_values(embeddings, decoding))
|
self.assertTrue(check_same_values(model.transformer.wte, model.lm_head))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
Reference in New Issue
Block a user