fix xlnet & transfotests
This commit is contained in:
@@ -519,20 +519,10 @@ class TFTransfoXLModelLanguageGenerationTest(unittest.TestCase):
|
|||||||
24,
|
24,
|
||||||
24,
|
24,
|
||||||
0,
|
0,
|
||||||
29546,
|
33,
|
||||||
40,
|
|
||||||
1092,
|
|
||||||
18,
|
|
||||||
8,
|
|
||||||
5854,
|
|
||||||
7,
|
|
||||||
1143,
|
|
||||||
2,
|
|
||||||
7,
|
|
||||||
1,
|
1,
|
||||||
159,
|
1857,
|
||||||
99,
|
2,
|
||||||
16,
|
|
||||||
1,
|
1,
|
||||||
1009,
|
1009,
|
||||||
4,
|
4,
|
||||||
|
|||||||
@@ -760,20 +760,10 @@ class TFXLNetModelLanguageGenerationTest(unittest.TestCase):
|
|||||||
9,
|
9,
|
||||||
4,
|
4,
|
||||||
3,
|
3,
|
||||||
1722,
|
|
||||||
19,
|
|
||||||
24,
|
|
||||||
6348,
|
|
||||||
61,
|
|
||||||
977,
|
|
||||||
176,
|
|
||||||
1772,
|
|
||||||
33,
|
|
||||||
45,
|
|
||||||
970,
|
|
||||||
19,
|
|
||||||
4185,
|
|
||||||
19,
|
19,
|
||||||
|
12943,
|
||||||
|
4354,
|
||||||
|
153,
|
||||||
27,
|
27,
|
||||||
442,
|
442,
|
||||||
22,
|
22,
|
||||||
|
|||||||
@@ -376,6 +376,7 @@ class TransfoXLModelLanguageGenerationTest(unittest.TestCase):
|
|||||||
# father initially slaps him for making such an accusation , Rasputin watches as the
|
# father initially slaps him for making such an accusation , Rasputin watches as the
|
||||||
# man is chased outside and beaten . Twenty years later , Rasputin sees a vision of
|
# man is chased outside and beaten . Twenty years later , Rasputin sees a vision of
|
||||||
# the Virgin Mary , prompting him to become a priest . Rasputin quickly becomes famous ,
|
# the Virgin Mary , prompting him to become a priest . Rasputin quickly becomes famous ,
|
||||||
|
|
||||||
# with people , even a bishop , begging for his blessing . <eod> </s> <eos>
|
# with people , even a bishop , begging for his blessing . <eod> </s> <eos>
|
||||||
|
|
||||||
expected_output_ids = [
|
expected_output_ids = [
|
||||||
@@ -520,20 +521,10 @@ class TransfoXLModelLanguageGenerationTest(unittest.TestCase):
|
|||||||
24,
|
24,
|
||||||
24,
|
24,
|
||||||
0,
|
0,
|
||||||
29546,
|
33,
|
||||||
40,
|
|
||||||
1092,
|
|
||||||
18,
|
|
||||||
8,
|
|
||||||
5854,
|
|
||||||
7,
|
|
||||||
1143,
|
|
||||||
2,
|
|
||||||
7,
|
|
||||||
1,
|
1,
|
||||||
159,
|
1857,
|
||||||
99,
|
2,
|
||||||
16,
|
|
||||||
1,
|
1,
|
||||||
1009,
|
1009,
|
||||||
4,
|
4,
|
||||||
|
|||||||
@@ -119,11 +119,11 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
|
|
||||||
input_ids_q = ids_tensor([self.batch_size, self.seq_length + 1], self.vocab_size)
|
input_ids_q = ids_tensor([self.batch_size, self.seq_length + 1], self.vocab_size)
|
||||||
perm_mask = torch.zeros(
|
perm_mask = torch.zeros(
|
||||||
self.batch_size, self.seq_length + 1, self.seq_length + 1, dtype=torch.float, device=torch_device
|
self.batch_size, self.seq_length + 1, self.seq_length + 1, dtype=torch.float, device=torch_device,
|
||||||
)
|
)
|
||||||
perm_mask[:, :, -1] = 1.0 # Previous tokens don't see last token
|
perm_mask[:, :, -1] = 1.0 # Previous tokens don't see last token
|
||||||
target_mapping = torch.zeros(
|
target_mapping = torch.zeros(
|
||||||
self.batch_size, 1, self.seq_length + 1, dtype=torch.float, device=torch_device
|
self.batch_size, 1, self.seq_length + 1, dtype=torch.float, device=torch_device,
|
||||||
)
|
)
|
||||||
target_mapping[:, 0, -1] = 1.0 # predict last token
|
target_mapping[:, 0, -1] = 1.0 # predict last token
|
||||||
|
|
||||||
@@ -212,7 +212,7 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
self.parent.assertEqual(len(no_mems_outputs), 1)
|
self.parent.assertEqual(len(no_mems_outputs), 1)
|
||||||
|
|
||||||
self.parent.assertListEqual(
|
self.parent.assertListEqual(
|
||||||
list(result["outputs"].size()), [self.batch_size, self.seq_length, self.hidden_size]
|
list(result["outputs"].size()), [self.batch_size, self.seq_length, self.hidden_size],
|
||||||
)
|
)
|
||||||
self.parent.assertListEqual(
|
self.parent.assertListEqual(
|
||||||
list(list(mem.size()) for mem in result["mems_1"]),
|
list(list(mem.size()) for mem in result["mems_1"]),
|
||||||
@@ -283,7 +283,7 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
|
|
||||||
self.parent.assertListEqual(list(result["loss_1"].size()), [])
|
self.parent.assertListEqual(list(result["loss_1"].size()), [])
|
||||||
self.parent.assertListEqual(
|
self.parent.assertListEqual(
|
||||||
list(result["all_logits_1"].size()), [self.batch_size, self.seq_length, self.vocab_size]
|
list(result["all_logits_1"].size()), [self.batch_size, self.seq_length, self.vocab_size],
|
||||||
)
|
)
|
||||||
self.parent.assertListEqual(
|
self.parent.assertListEqual(
|
||||||
list(list(mem.size()) for mem in result["mems_1"]),
|
list(list(mem.size()) for mem in result["mems_1"]),
|
||||||
@@ -292,7 +292,7 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
|
|
||||||
self.parent.assertListEqual(list(result["loss_2"].size()), [])
|
self.parent.assertListEqual(list(result["loss_2"].size()), [])
|
||||||
self.parent.assertListEqual(
|
self.parent.assertListEqual(
|
||||||
list(result["all_logits_2"].size()), [self.batch_size, self.seq_length, self.vocab_size]
|
list(result["all_logits_2"].size()), [self.batch_size, self.seq_length, self.vocab_size],
|
||||||
)
|
)
|
||||||
self.parent.assertListEqual(
|
self.parent.assertListEqual(
|
||||||
list(list(mem.size()) for mem in result["mems_2"]),
|
list(list(mem.size()) for mem in result["mems_2"]),
|
||||||
@@ -319,7 +319,7 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
outputs = model(input_ids_1)
|
outputs = model(input_ids_1)
|
||||||
start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits, mems = outputs
|
(start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits, mems,) = outputs
|
||||||
|
|
||||||
outputs = model(
|
outputs = model(
|
||||||
input_ids_1,
|
input_ids_1,
|
||||||
@@ -340,7 +340,7 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
|
|
||||||
total_loss, mems = outputs
|
total_loss, mems = outputs
|
||||||
|
|
||||||
outputs = model(input_ids_1, start_positions=sequence_labels, end_positions=sequence_labels)
|
outputs = model(input_ids_1, start_positions=sequence_labels, end_positions=sequence_labels,)
|
||||||
|
|
||||||
total_loss, mems = outputs
|
total_loss, mems = outputs
|
||||||
|
|
||||||
@@ -356,10 +356,10 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
|
|
||||||
self.parent.assertListEqual(list(result["loss"].size()), [])
|
self.parent.assertListEqual(list(result["loss"].size()), [])
|
||||||
self.parent.assertListEqual(
|
self.parent.assertListEqual(
|
||||||
list(result["start_top_log_probs"].size()), [self.batch_size, model.config.start_n_top]
|
list(result["start_top_log_probs"].size()), [self.batch_size, model.config.start_n_top],
|
||||||
)
|
)
|
||||||
self.parent.assertListEqual(
|
self.parent.assertListEqual(
|
||||||
list(result["start_top_index"].size()), [self.batch_size, model.config.start_n_top]
|
list(result["start_top_index"].size()), [self.batch_size, model.config.start_n_top],
|
||||||
)
|
)
|
||||||
self.parent.assertListEqual(
|
self.parent.assertListEqual(
|
||||||
list(result["end_top_log_probs"].size()),
|
list(result["end_top_log_probs"].size()),
|
||||||
@@ -405,7 +405,7 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
|
|
||||||
self.parent.assertListEqual(list(result["loss"].size()), [])
|
self.parent.assertListEqual(list(result["loss"].size()), [])
|
||||||
self.parent.assertListEqual(
|
self.parent.assertListEqual(
|
||||||
list(result["logits"].size()), [self.batch_size, self.seq_length, self.type_sequence_label_size]
|
list(result["logits"].size()), [self.batch_size, self.seq_length, self.type_sequence_label_size],
|
||||||
)
|
)
|
||||||
self.parent.assertListEqual(
|
self.parent.assertListEqual(
|
||||||
list(list(mem.size()) for mem in result["mems_1"]),
|
list(list(mem.size()) for mem in result["mems_1"]),
|
||||||
@@ -442,7 +442,7 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
|
|
||||||
self.parent.assertListEqual(list(result["loss"].size()), [])
|
self.parent.assertListEqual(list(result["loss"].size()), [])
|
||||||
self.parent.assertListEqual(
|
self.parent.assertListEqual(
|
||||||
list(result["logits"].size()), [self.batch_size, self.type_sequence_label_size]
|
list(result["logits"].size()), [self.batch_size, self.type_sequence_label_size],
|
||||||
)
|
)
|
||||||
self.parent.assertListEqual(
|
self.parent.assertListEqual(
|
||||||
list(list(mem.size()) for mem in result["mems_1"]),
|
list(list(mem.size()) for mem in result["mems_1"]),
|
||||||
@@ -859,20 +859,10 @@ class XLNetModelLanguageGenerationTest(unittest.TestCase):
|
|||||||
9,
|
9,
|
||||||
4,
|
4,
|
||||||
3,
|
3,
|
||||||
1722,
|
|
||||||
19,
|
|
||||||
24,
|
|
||||||
6348,
|
|
||||||
61,
|
|
||||||
977,
|
|
||||||
176,
|
|
||||||
1772,
|
|
||||||
33,
|
|
||||||
45,
|
|
||||||
970,
|
|
||||||
19,
|
|
||||||
4185,
|
|
||||||
19,
|
19,
|
||||||
|
12943,
|
||||||
|
4354,
|
||||||
|
153,
|
||||||
27,
|
27,
|
||||||
442,
|
442,
|
||||||
22,
|
22,
|
||||||
@@ -922,5 +912,4 @@ class XLNetModelLanguageGenerationTest(unittest.TestCase):
|
|||||||
# the men are forced to leave the monastery. Rasputin is forced to return to
|
# the men are forced to leave the monastery. Rasputin is forced to return to
|
||||||
|
|
||||||
output_ids = model.generate(input_ids, max_length=200, do_sample=False)
|
output_ids = model.generate(input_ids, max_length=200, do_sample=False)
|
||||||
|
|
||||||
self.assertListEqual(output_ids[0].tolist(), expected_output_ids)
|
self.assertListEqual(output_ids[0].tolist(), expected_output_ids)
|
||||||
|
|||||||
Reference in New Issue
Block a user