Test suite testing the tie_weights function as well as the resize_token_embeddings function.

Patched an issue relating to the tied weights I had introduced with the TorchScript addition.
Byte order mark management in TSV glue reading.
This commit is contained in:
LysandreJik
2019-07-11 22:09:16 -04:00
parent bd404735a7
commit 6c2ee16c04
5 changed files with 81 additions and 7 deletions

View File

@@ -587,7 +587,7 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
if self.config.torchscript:
self.lm_head.weight = nn.Parameter(input_embeddings.clone())
else:
self.lm_head.weight = input_embeddings # Tied weights
self.lm_head = self.transformer.tokens_embed # Tied weights
def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None, head_mask=None):
"""
@@ -700,7 +700,7 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
if self.config.torchscript:
self.lm_head.weight = nn.Parameter(input_embeddings.clone())
else:
self.lm_head.weight = input_embeddings # Tied weights
self.lm_head = self.transformer.tokens_embed # Tied weights
def forward(self, input_ids, mc_token_ids=None, lm_labels=None, mc_labels=None, token_type_ids=None,
position_ids=None, head_mask=None):