fix xlnet test

This commit is contained in:
thomwolf
2019-12-05 13:35:29 +01:00
parent 6c5297a423
commit 3268ebd229

View File

@@ -167,7 +167,7 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers) [[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
def create_and_check_xlnet_base_model_with_att_output(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask, def create_and_check_xlnet_base_model_with_att_output(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels): target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels, token_labels):
model = XLNetModel(config) model = XLNetModel(config)
model.eval() model.eval()