fix xlnet test
This commit is contained in:
@@ -167,7 +167,7 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
|
|||||||
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
|
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
|
||||||
|
|
||||||
def create_and_check_xlnet_base_model_with_att_output(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
|
def create_and_check_xlnet_base_model_with_att_output(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
|
||||||
target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels):
|
target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels, token_labels):
|
||||||
model = XLNetModel(config)
|
model = XLNetModel(config)
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user