From b4a3a647448a1d54a9d130670344643a19a87d0d Mon Sep 17 00:00:00 2001 From: patrickvonplaten Date: Sun, 8 Mar 2020 16:25:03 +0100 Subject: [PATCH] fix xlnet & transfotests --- tests/test_modeling_tf_transfo_xl.py | 16 +++--------- tests/test_modeling_tf_xlnet.py | 16 +++--------- tests/test_modeling_transfo_xl.py | 17 +++--------- tests/test_modeling_xlnet.py | 39 ++++++++++------------------ 4 files changed, 24 insertions(+), 64 deletions(-) diff --git a/tests/test_modeling_tf_transfo_xl.py b/tests/test_modeling_tf_transfo_xl.py index 8b7a514a04..fbe04f4b9d 100644 --- a/tests/test_modeling_tf_transfo_xl.py +++ b/tests/test_modeling_tf_transfo_xl.py @@ -519,20 +519,10 @@ class TFTransfoXLModelLanguageGenerationTest(unittest.TestCase): 24, 24, 0, - 29546, - 40, - 1092, - 18, - 8, - 5854, - 7, - 1143, - 2, - 7, + 33, 1, - 159, - 99, - 16, + 1857, + 2, 1, 1009, 4, diff --git a/tests/test_modeling_tf_xlnet.py b/tests/test_modeling_tf_xlnet.py index 95c72b6ea5..ed2d94b93e 100644 --- a/tests/test_modeling_tf_xlnet.py +++ b/tests/test_modeling_tf_xlnet.py @@ -760,20 +760,10 @@ class TFXLNetModelLanguageGenerationTest(unittest.TestCase): 9, 4, 3, - 1722, - 19, - 24, - 6348, - 61, - 977, - 176, - 1772, - 33, - 45, - 970, - 19, - 4185, 19, + 12943, + 4354, + 153, 27, 442, 22, diff --git a/tests/test_modeling_transfo_xl.py b/tests/test_modeling_transfo_xl.py index 94a5293cec..4dfc1db428 100644 --- a/tests/test_modeling_transfo_xl.py +++ b/tests/test_modeling_transfo_xl.py @@ -376,6 +376,7 @@ class TransfoXLModelLanguageGenerationTest(unittest.TestCase): # father initially slaps him for making such an accusation , Rasputin watches as the # man is chased outside and beaten . Twenty years later , Rasputin sees a vision of # the Virgin Mary , prompting him to become a priest . Rasputin quickly becomes famous , + # with people , even a bishop , begging for his blessing . expected_output_ids = [ @@ -520,20 +521,10 @@ class TransfoXLModelLanguageGenerationTest(unittest.TestCase): 24, 24, 0, - 29546, - 40, - 1092, - 18, - 8, - 5854, - 7, - 1143, - 2, - 7, + 33, 1, - 159, - 99, - 16, + 1857, + 2, 1, 1009, 4, diff --git a/tests/test_modeling_xlnet.py b/tests/test_modeling_xlnet.py index 678d73c0d7..a4bc104c18 100644 --- a/tests/test_modeling_xlnet.py +++ b/tests/test_modeling_xlnet.py @@ -119,11 +119,11 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase): input_ids_q = ids_tensor([self.batch_size, self.seq_length + 1], self.vocab_size) perm_mask = torch.zeros( - self.batch_size, self.seq_length + 1, self.seq_length + 1, dtype=torch.float, device=torch_device + self.batch_size, self.seq_length + 1, self.seq_length + 1, dtype=torch.float, device=torch_device, ) perm_mask[:, :, -1] = 1.0 # Previous tokens don't see last token target_mapping = torch.zeros( - self.batch_size, 1, self.seq_length + 1, dtype=torch.float, device=torch_device + self.batch_size, 1, self.seq_length + 1, dtype=torch.float, device=torch_device, ) target_mapping[:, 0, -1] = 1.0 # predict last token @@ -212,7 +212,7 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase): self.parent.assertEqual(len(no_mems_outputs), 1) self.parent.assertListEqual( - list(result["outputs"].size()), [self.batch_size, self.seq_length, self.hidden_size] + list(result["outputs"].size()), [self.batch_size, self.seq_length, self.hidden_size], ) self.parent.assertListEqual( list(list(mem.size()) for mem in result["mems_1"]), @@ -283,7 +283,7 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase): self.parent.assertListEqual(list(result["loss_1"].size()), []) self.parent.assertListEqual( - list(result["all_logits_1"].size()), [self.batch_size, self.seq_length, self.vocab_size] + list(result["all_logits_1"].size()), [self.batch_size, self.seq_length, self.vocab_size], ) self.parent.assertListEqual( list(list(mem.size()) for mem in result["mems_1"]), @@ -292,7 +292,7 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase): self.parent.assertListEqual(list(result["loss_2"].size()), []) self.parent.assertListEqual( - list(result["all_logits_2"].size()), [self.batch_size, self.seq_length, self.vocab_size] + list(result["all_logits_2"].size()), [self.batch_size, self.seq_length, self.vocab_size], ) self.parent.assertListEqual( list(list(mem.size()) for mem in result["mems_2"]), @@ -319,7 +319,7 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase): model.eval() outputs = model(input_ids_1) - start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits, mems = outputs + (start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits, mems,) = outputs outputs = model( input_ids_1, @@ -340,7 +340,7 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase): total_loss, mems = outputs - outputs = model(input_ids_1, start_positions=sequence_labels, end_positions=sequence_labels) + outputs = model(input_ids_1, start_positions=sequence_labels, end_positions=sequence_labels,) total_loss, mems = outputs @@ -356,10 +356,10 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase): self.parent.assertListEqual(list(result["loss"].size()), []) self.parent.assertListEqual( - list(result["start_top_log_probs"].size()), [self.batch_size, model.config.start_n_top] + list(result["start_top_log_probs"].size()), [self.batch_size, model.config.start_n_top], ) self.parent.assertListEqual( - list(result["start_top_index"].size()), [self.batch_size, model.config.start_n_top] + list(result["start_top_index"].size()), [self.batch_size, model.config.start_n_top], ) self.parent.assertListEqual( list(result["end_top_log_probs"].size()), @@ -405,7 +405,7 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase): self.parent.assertListEqual(list(result["loss"].size()), []) self.parent.assertListEqual( - list(result["logits"].size()), [self.batch_size, self.seq_length, self.type_sequence_label_size] + list(result["logits"].size()), [self.batch_size, self.seq_length, self.type_sequence_label_size], ) self.parent.assertListEqual( list(list(mem.size()) for mem in result["mems_1"]), @@ -442,7 +442,7 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase): self.parent.assertListEqual(list(result["loss"].size()), []) self.parent.assertListEqual( - list(result["logits"].size()), [self.batch_size, self.type_sequence_label_size] + list(result["logits"].size()), [self.batch_size, self.type_sequence_label_size], ) self.parent.assertListEqual( list(list(mem.size()) for mem in result["mems_1"]), @@ -859,20 +859,10 @@ class XLNetModelLanguageGenerationTest(unittest.TestCase): 9, 4, 3, - 1722, - 19, - 24, - 6348, - 61, - 977, - 176, - 1772, - 33, - 45, - 970, - 19, - 4185, 19, + 12943, + 4354, + 153, 27, 442, 22, @@ -922,5 +912,4 @@ class XLNetModelLanguageGenerationTest(unittest.TestCase): # the men are forced to leave the monastery. Rasputin is forced to return to output_ids = model.generate(input_ids, max_length=200, do_sample=False) - self.assertListEqual(output_ids[0].tolist(), expected_output_ids)