fix xlnet test

This commit is contained in:
thomwolf
2019-12-05 13:35:29 +01:00
parent 6c5297a423
commit 3268ebd229

View File

@@ -167,7 +167,7 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
def create_and_check_xlnet_base_model_with_att_output(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels):
target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels, token_labels):
model = XLNetModel(config)
model.eval()