updating tests

This commit is contained in:
thomwolf
2019-07-12 10:57:58 +02:00
parent 3fbceed8d2
commit 2918b7d2a0
14 changed files with 672 additions and 596 deletions

View File

@@ -414,6 +414,7 @@ class GPT2Model(GPT2PreTrainedModel):
def _resize_token_embeddings(self, new_num_tokens):
self.wte = self._get_resized_embeddings(self.wte, new_num_tokens)
return self.wte
def _prune_heads(self, heads_to_prune):
""" Prunes heads of the model.
@@ -562,11 +563,8 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
""" Make sure we are sharing the input and output embeddings.
Export to TorchScript can't handle parameter sharing so we are cloning them instead.
"""
input_embeddings = self.transformer.wte.weight
if self.config.torchscript:
self.lm_head.weight = nn.Parameter(input_embeddings.clone())
else:
self.lm_head.weight = input_embeddings # Tied weights
self._tie_or_clone_weights(self.lm_head,
self.transformer.wte)
def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None, past=None, head_mask=None):
"""
@@ -658,11 +656,8 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
""" Make sure we are sharing the input and output embeddings.
Export to TorchScript can't handle parameter sharing so we are cloning them instead.
"""
input_embeddings = self.transformer.wte.weight
if self.config.torchscript:
self.lm_head.weight = nn.Parameter(input_embeddings.clone())
else:
self.lm_head.weight = input_embeddings # Tied weights
self._tie_or_clone_weights(self.lm_head,
self.transformer.wte)
def forward(self, input_ids, mc_token_ids=None, lm_labels=None, mc_labels=None, token_type_ids=None,
position_ids=None, past=None, head_mask=None):