fix xlnet & transfotests

This commit is contained in:
patrickvonplaten
2020-03-08 16:25:03 +01:00
parent 66c827656f
commit b4a3a64744
4 changed files with 24 additions and 64 deletions

View File

@@ -119,11 +119,11 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
input_ids_q = ids_tensor([self.batch_size, self.seq_length + 1], self.vocab_size)
perm_mask = torch.zeros(
self.batch_size, self.seq_length + 1, self.seq_length + 1, dtype=torch.float, device=torch_device
self.batch_size, self.seq_length + 1, self.seq_length + 1, dtype=torch.float, device=torch_device,
)
perm_mask[:, :, -1] = 1.0 # Previous tokens don't see last token
target_mapping = torch.zeros(
self.batch_size, 1, self.seq_length + 1, dtype=torch.float, device=torch_device
self.batch_size, 1, self.seq_length + 1, dtype=torch.float, device=torch_device,
)
target_mapping[:, 0, -1] = 1.0 # predict last token
@@ -212,7 +212,7 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
self.parent.assertEqual(len(no_mems_outputs), 1)
self.parent.assertListEqual(
list(result["outputs"].size()), [self.batch_size, self.seq_length, self.hidden_size]
list(result["outputs"].size()), [self.batch_size, self.seq_length, self.hidden_size],
)
self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems_1"]),
@@ -283,7 +283,7 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
self.parent.assertListEqual(list(result["loss_1"].size()), [])
self.parent.assertListEqual(
list(result["all_logits_1"].size()), [self.batch_size, self.seq_length, self.vocab_size]
list(result["all_logits_1"].size()), [self.batch_size, self.seq_length, self.vocab_size],
)
self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems_1"]),
@@ -292,7 +292,7 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
self.parent.assertListEqual(list(result["loss_2"].size()), [])
self.parent.assertListEqual(
list(result["all_logits_2"].size()), [self.batch_size, self.seq_length, self.vocab_size]
list(result["all_logits_2"].size()), [self.batch_size, self.seq_length, self.vocab_size],
)
self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems_2"]),
@@ -319,7 +319,7 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
model.eval()
outputs = model(input_ids_1)
start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits, mems = outputs
(start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits, mems,) = outputs
outputs = model(
input_ids_1,
@@ -340,7 +340,7 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
total_loss, mems = outputs
outputs = model(input_ids_1, start_positions=sequence_labels, end_positions=sequence_labels)
outputs = model(input_ids_1, start_positions=sequence_labels, end_positions=sequence_labels,)
total_loss, mems = outputs
@@ -356,10 +356,10 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
self.parent.assertListEqual(list(result["loss"].size()), [])
self.parent.assertListEqual(
list(result["start_top_log_probs"].size()), [self.batch_size, model.config.start_n_top]
list(result["start_top_log_probs"].size()), [self.batch_size, model.config.start_n_top],
)
self.parent.assertListEqual(
list(result["start_top_index"].size()), [self.batch_size, model.config.start_n_top]
list(result["start_top_index"].size()), [self.batch_size, model.config.start_n_top],
)
self.parent.assertListEqual(
list(result["end_top_log_probs"].size()),
@@ -405,7 +405,7 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
self.parent.assertListEqual(list(result["loss"].size()), [])
self.parent.assertListEqual(
list(result["logits"].size()), [self.batch_size, self.seq_length, self.type_sequence_label_size]
list(result["logits"].size()), [self.batch_size, self.seq_length, self.type_sequence_label_size],
)
self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems_1"]),
@@ -442,7 +442,7 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
self.parent.assertListEqual(list(result["loss"].size()), [])
self.parent.assertListEqual(
list(result["logits"].size()), [self.batch_size, self.type_sequence_label_size]
list(result["logits"].size()), [self.batch_size, self.type_sequence_label_size],
)
self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems_1"]),
@@ -859,20 +859,10 @@ class XLNetModelLanguageGenerationTest(unittest.TestCase):
9,
4,
3,
1722,
19,
24,
6348,
61,
977,
176,
1772,
33,
45,
970,
19,
4185,
19,
12943,
4354,
153,
27,
442,
22,
@@ -922,5 +912,4 @@ class XLNetModelLanguageGenerationTest(unittest.TestCase):
# the men are forced to leave the monastery. Rasputin is forced to return to
output_ids = model.generate(input_ids, max_length=200, do_sample=False)
self.assertListEqual(output_ids[0].tolist(), expected_output_ids)