updating tests
This commit is contained in:
@@ -430,6 +430,7 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
|
||||
|
||||
def _resize_token_embeddings(self, new_num_tokens):
|
||||
self.tokens_embed = self._get_resized_embeddings(self.tokens_embed, new_num_tokens)
|
||||
return self.tokens_embed
|
||||
|
||||
def _prune_heads(self, heads_to_prune):
|
||||
""" Prunes heads of the model.
|
||||
@@ -583,11 +584,8 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
|
||||
""" Make sure we are sharing the input and output embeddings.
|
||||
Export to TorchScript can't handle parameter sharing so we are cloning them instead.
|
||||
"""
|
||||
input_embeddings = self.transformer.tokens_embed.weight
|
||||
if self.config.torchscript:
|
||||
self.lm_head.weight = nn.Parameter(input_embeddings.clone())
|
||||
else:
|
||||
self.lm_head.weight = input_embeddings # Tied weights
|
||||
self._tie_or_clone_weights(self.lm_head,
|
||||
self.transformer.tokens_embed)
|
||||
|
||||
def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None, head_mask=None):
|
||||
"""
|
||||
@@ -696,11 +694,8 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
|
||||
""" Make sure we are sharing the input and output embeddings.
|
||||
Export to TorchScript can't handle parameter sharing so we are cloning them instead.
|
||||
"""
|
||||
input_embeddings = self.transformer.tokens_embed.weight
|
||||
if self.config.torchscript:
|
||||
self.lm_head.weight = nn.Parameter(input_embeddings.clone())
|
||||
else:
|
||||
self.lm_head.weight = input_embeddings # Tied weights
|
||||
self._tie_or_clone_weights(self.lm_head,
|
||||
self.transformer.tokens_embed)
|
||||
|
||||
def forward(self, input_ids, mc_token_ids=None, lm_labels=None, mc_labels=None, token_type_ids=None,
|
||||
position_ids=None, head_mask=None):
|
||||
|
||||
Reference in New Issue
Block a user