fixed tests
This commit is contained in:
@@ -768,8 +768,9 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
|
||||
|
||||
def forward(self, input_ids, lengths=None, position_ids=None, langs=None, token_type_ids=None,
|
||||
attention_mask=None, cache=None, labels=None, head_mask=None):
|
||||
transformer_outputs = self.transformer(input_ids, lengths=lengths, position_ids=position_ids, token_type_ids=token_type_ids,
|
||||
langs=langs, attention_mask=attention_mask, cache=cache, head_mask=head_mask)
|
||||
transformer_outputs = self.transformer(input_ids, lengths=lengths, position_ids=position_ids,
|
||||
token_type_ids=token_type_ids, langs=langs,
|
||||
attention_mask=attention_mask, cache=cache, head_mask=head_mask)
|
||||
|
||||
output = transformer_outputs[0]
|
||||
outputs = self.pred_layer(output, labels)
|
||||
@@ -825,8 +826,9 @@ class XLMForSequenceClassification(XLMPreTrainedModel):
|
||||
|
||||
def forward(self, input_ids, lengths=None, position_ids=None, langs=None, token_type_ids=None,
|
||||
attention_mask=None, cache=None, labels=None, head_mask=None):
|
||||
transformer_outputs = self.transformer(input_ids, lengths=lengths, position_ids=position_ids, token_type_ids=token_type_ids,
|
||||
langs=langs, attention_mask=attention_mask, cache=cache, head_mask=head_mask)
|
||||
transformer_outputs = self.transformer(input_ids, lengths=lengths, position_ids=position_ids,
|
||||
token_type_ids=token_type_ids, langs=langs,
|
||||
attention_mask=attention_mask, cache=cache, head_mask=head_mask)
|
||||
|
||||
output = transformer_outputs[0]
|
||||
logits = self.sequence_summary(output)
|
||||
@@ -905,8 +907,9 @@ class XLMForQuestionAnswering(XLMPreTrainedModel):
|
||||
def forward(self, input_ids, lengths=None, position_ids=None, langs=None, token_type_ids=None,
|
||||
attention_mask=None, cache=None, start_positions=None, end_positions=None,
|
||||
cls_index=None, is_impossible=None, p_mask=None, head_mask=None):
|
||||
transformer_outputs = self.transformer(input_ids, lengths=lengths, position_ids=position_ids, token_type_ids=token_type_ids,
|
||||
langs=langs, attention_mask=attention_mask, cache=cache, head_mask=head_mask)
|
||||
transformer_outputs = self.transformer(input_ids, lengths=lengths, position_ids=position_ids,
|
||||
token_type_ids=token_type_ids, langs=langs,
|
||||
attention_mask=attention_mask, cache=cache, head_mask=head_mask)
|
||||
|
||||
output = transformer_outputs[0]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user