fixed tests

This commit is contained in:
thomwolf
2019-07-15 12:32:19 +02:00
parent e28d8bde0d
commit f7cd7392fd
7 changed files with 63 additions and 38 deletions

View File

@@ -591,7 +591,8 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
self.transformer.wte)
def forward(self, input_ids, position_ids=None, token_type_ids=None, labels=None, past=None, head_mask=None):
transformer_outputs = self.transformer(input_ids, position_ids, token_type_ids, past, head_mask)
transformer_outputs = self.transformer(input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
past=past, head_mask=head_mask)
hidden_states = transformer_outputs[0]
lm_logits = self.lm_head(hidden_states)
@@ -709,7 +710,8 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
def forward(self, input_ids, mc_token_ids=None, lm_labels=None, mc_labels=None, token_type_ids=None,
position_ids=None, past=None, head_mask=None):
transformer_outputs = self.transformer(input_ids, position_ids, token_type_ids, past, head_mask)
transformer_outputs = self.transformer(input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
past=past, head_mask=head_mask)
hidden_states = transformer_outputs[0]
lm_logits = self.lm_head(hidden_states)