fix xlnet & transfotests
This commit is contained in:
@@ -119,11 +119,11 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
input_ids_q = ids_tensor([self.batch_size, self.seq_length + 1], self.vocab_size)
|
||||
perm_mask = torch.zeros(
|
||||
self.batch_size, self.seq_length + 1, self.seq_length + 1, dtype=torch.float, device=torch_device
|
||||
self.batch_size, self.seq_length + 1, self.seq_length + 1, dtype=torch.float, device=torch_device,
|
||||
)
|
||||
perm_mask[:, :, -1] = 1.0 # Previous tokens don't see last token
|
||||
target_mapping = torch.zeros(
|
||||
self.batch_size, 1, self.seq_length + 1, dtype=torch.float, device=torch_device
|
||||
self.batch_size, 1, self.seq_length + 1, dtype=torch.float, device=torch_device,
|
||||
)
|
||||
target_mapping[:, 0, -1] = 1.0 # predict last token
|
||||
|
||||
@@ -212,7 +212,7 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
self.parent.assertEqual(len(no_mems_outputs), 1)
|
||||
|
||||
self.parent.assertListEqual(
|
||||
list(result["outputs"].size()), [self.batch_size, self.seq_length, self.hidden_size]
|
||||
list(result["outputs"].size()), [self.batch_size, self.seq_length, self.hidden_size],
|
||||
)
|
||||
self.parent.assertListEqual(
|
||||
list(list(mem.size()) for mem in result["mems_1"]),
|
||||
@@ -283,7 +283,7 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
self.parent.assertListEqual(list(result["loss_1"].size()), [])
|
||||
self.parent.assertListEqual(
|
||||
list(result["all_logits_1"].size()), [self.batch_size, self.seq_length, self.vocab_size]
|
||||
list(result["all_logits_1"].size()), [self.batch_size, self.seq_length, self.vocab_size],
|
||||
)
|
||||
self.parent.assertListEqual(
|
||||
list(list(mem.size()) for mem in result["mems_1"]),
|
||||
@@ -292,7 +292,7 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
self.parent.assertListEqual(list(result["loss_2"].size()), [])
|
||||
self.parent.assertListEqual(
|
||||
list(result["all_logits_2"].size()), [self.batch_size, self.seq_length, self.vocab_size]
|
||||
list(result["all_logits_2"].size()), [self.batch_size, self.seq_length, self.vocab_size],
|
||||
)
|
||||
self.parent.assertListEqual(
|
||||
list(list(mem.size()) for mem in result["mems_2"]),
|
||||
@@ -319,7 +319,7 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
model.eval()
|
||||
|
||||
outputs = model(input_ids_1)
|
||||
start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits, mems = outputs
|
||||
(start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits, mems,) = outputs
|
||||
|
||||
outputs = model(
|
||||
input_ids_1,
|
||||
@@ -340,7 +340,7 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
total_loss, mems = outputs
|
||||
|
||||
outputs = model(input_ids_1, start_positions=sequence_labels, end_positions=sequence_labels)
|
||||
outputs = model(input_ids_1, start_positions=sequence_labels, end_positions=sequence_labels,)
|
||||
|
||||
total_loss, mems = outputs
|
||||
|
||||
@@ -356,10 +356,10 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
self.parent.assertListEqual(list(result["loss"].size()), [])
|
||||
self.parent.assertListEqual(
|
||||
list(result["start_top_log_probs"].size()), [self.batch_size, model.config.start_n_top]
|
||||
list(result["start_top_log_probs"].size()), [self.batch_size, model.config.start_n_top],
|
||||
)
|
||||
self.parent.assertListEqual(
|
||||
list(result["start_top_index"].size()), [self.batch_size, model.config.start_n_top]
|
||||
list(result["start_top_index"].size()), [self.batch_size, model.config.start_n_top],
|
||||
)
|
||||
self.parent.assertListEqual(
|
||||
list(result["end_top_log_probs"].size()),
|
||||
@@ -405,7 +405,7 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
self.parent.assertListEqual(list(result["loss"].size()), [])
|
||||
self.parent.assertListEqual(
|
||||
list(result["logits"].size()), [self.batch_size, self.seq_length, self.type_sequence_label_size]
|
||||
list(result["logits"].size()), [self.batch_size, self.seq_length, self.type_sequence_label_size],
|
||||
)
|
||||
self.parent.assertListEqual(
|
||||
list(list(mem.size()) for mem in result["mems_1"]),
|
||||
@@ -442,7 +442,7 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
self.parent.assertListEqual(list(result["loss"].size()), [])
|
||||
self.parent.assertListEqual(
|
||||
list(result["logits"].size()), [self.batch_size, self.type_sequence_label_size]
|
||||
list(result["logits"].size()), [self.batch_size, self.type_sequence_label_size],
|
||||
)
|
||||
self.parent.assertListEqual(
|
||||
list(list(mem.size()) for mem in result["mems_1"]),
|
||||
@@ -859,20 +859,10 @@ class XLNetModelLanguageGenerationTest(unittest.TestCase):
|
||||
9,
|
||||
4,
|
||||
3,
|
||||
1722,
|
||||
19,
|
||||
24,
|
||||
6348,
|
||||
61,
|
||||
977,
|
||||
176,
|
||||
1772,
|
||||
33,
|
||||
45,
|
||||
970,
|
||||
19,
|
||||
4185,
|
||||
19,
|
||||
12943,
|
||||
4354,
|
||||
153,
|
||||
27,
|
||||
442,
|
||||
22,
|
||||
@@ -922,5 +912,4 @@ class XLNetModelLanguageGenerationTest(unittest.TestCase):
|
||||
# the men are forced to leave the monastery. Rasputin is forced to return to
|
||||
|
||||
output_ids = model.generate(input_ids, max_length=200, do_sample=False)
|
||||
|
||||
self.assertListEqual(output_ids[0].tolist(), expected_output_ids)
|
||||
|
||||
Reference in New Issue
Block a user