fixed tests

This commit is contained in:
thomwolf
2019-07-15 12:32:19 +02:00
parent e28d8bde0d
commit f7cd7392fd
7 changed files with 63 additions and 38 deletions

View File

@@ -582,7 +582,8 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
self.transformer.tokens_embed)
def forward(self, input_ids, position_ids=None, token_type_ids=None, labels=None, head_mask=None):
transformer_outputs = self.transformer(input_ids, position_ids, token_type_ids, head_mask)
transformer_outputs = self.transformer(input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
head_mask=head_mask)
hidden_states = transformer_outputs[0]
lm_logits = self.lm_head(hidden_states)
@@ -693,7 +694,8 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
def forward(self, input_ids, mc_token_ids=None, lm_labels=None, mc_labels=None, token_type_ids=None,
position_ids=None, head_mask=None):
transformer_outputs = self.transformer(input_ids, position_ids, token_type_ids, head_mask)
transformer_outputs = self.transformer(input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
head_mask=head_mask)
hidden_states = transformer_outputs[0]
lm_logits = self.lm_head(hidden_states)