diff --git a/transformers/tests/modeling_xlnet_test.py b/transformers/tests/modeling_xlnet_test.py index 3f0ef6793c..38888d4488 100644 --- a/transformers/tests/modeling_xlnet_test.py +++ b/transformers/tests/modeling_xlnet_test.py @@ -167,7 +167,7 @@ class XLNetModelTest(CommonTestCases.CommonModelTester): [[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers) def create_and_check_xlnet_base_model_with_att_output(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask, - target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels): + target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels, token_labels): model = XLNetModel(config) model.eval()