fixed tests
This commit is contained in:
@@ -582,7 +582,8 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
|
||||
self.transformer.tokens_embed)
|
||||
|
||||
def forward(self, input_ids, position_ids=None, token_type_ids=None, labels=None, head_mask=None):
|
||||
transformer_outputs = self.transformer(input_ids, position_ids, token_type_ids, head_mask)
|
||||
transformer_outputs = self.transformer(input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
|
||||
head_mask=head_mask)
|
||||
hidden_states = transformer_outputs[0]
|
||||
lm_logits = self.lm_head(hidden_states)
|
||||
|
||||
@@ -693,7 +694,8 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
|
||||
|
||||
def forward(self, input_ids, mc_token_ids=None, lm_labels=None, mc_labels=None, token_type_ids=None,
|
||||
position_ids=None, head_mask=None):
|
||||
transformer_outputs = self.transformer(input_ids, position_ids, token_type_ids, head_mask)
|
||||
transformer_outputs = self.transformer(input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
|
||||
head_mask=head_mask)
|
||||
hidden_states = transformer_outputs[0]
|
||||
|
||||
lm_logits = self.lm_head(hidden_states)
|
||||
|
||||
Reference in New Issue
Block a user