fixed tests

This commit is contained in:
thomwolf
2019-07-15 12:32:19 +02:00
parent e28d8bde0d
commit f7cd7392fd
7 changed files with 63 additions and 38 deletions

View File

@@ -768,8 +768,9 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
def forward(self, input_ids, lengths=None, position_ids=None, langs=None, token_type_ids=None,
attention_mask=None, cache=None, labels=None, head_mask=None):
transformer_outputs = self.transformer(input_ids, lengths=lengths, position_ids=position_ids, token_type_ids=token_type_ids,
langs=langs, attention_mask=attention_mask, cache=cache, head_mask=head_mask)
transformer_outputs = self.transformer(input_ids, lengths=lengths, position_ids=position_ids,
token_type_ids=token_type_ids, langs=langs,
attention_mask=attention_mask, cache=cache, head_mask=head_mask)
output = transformer_outputs[0]
outputs = self.pred_layer(output, labels)
@@ -825,8 +826,9 @@ class XLMForSequenceClassification(XLMPreTrainedModel):
def forward(self, input_ids, lengths=None, position_ids=None, langs=None, token_type_ids=None,
attention_mask=None, cache=None, labels=None, head_mask=None):
transformer_outputs = self.transformer(input_ids, lengths=lengths, position_ids=position_ids, token_type_ids=token_type_ids,
langs=langs, attention_mask=attention_mask, cache=cache, head_mask=head_mask)
transformer_outputs = self.transformer(input_ids, lengths=lengths, position_ids=position_ids,
token_type_ids=token_type_ids, langs=langs,
attention_mask=attention_mask, cache=cache, head_mask=head_mask)
output = transformer_outputs[0]
logits = self.sequence_summary(output)
@@ -905,8 +907,9 @@ class XLMForQuestionAnswering(XLMPreTrainedModel):
def forward(self, input_ids, lengths=None, position_ids=None, langs=None, token_type_ids=None,
attention_mask=None, cache=None, start_positions=None, end_positions=None,
cls_index=None, is_impossible=None, p_mask=None, head_mask=None):
transformer_outputs = self.transformer(input_ids, lengths=lengths, position_ids=position_ids, token_type_ids=token_type_ids,
langs=langs, attention_mask=attention_mask, cache=cache, head_mask=head_mask)
transformer_outputs = self.transformer(input_ids, lengths=lengths, position_ids=position_ids,
token_type_ids=token_type_ids, langs=langs,
attention_mask=attention_mask, cache=cache, head_mask=head_mask)
output = transformer_outputs[0]