From bcc9e93e6f585eec96444218b61b517f3f2f6314 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Fri, 14 Jun 2019 15:38:20 +0200 Subject: [PATCH] fix test --- tests/modeling_gpt2_test.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/modeling_gpt2_test.py b/tests/modeling_gpt2_test.py index 41cc9b8fd3..7817b98875 100644 --- a/tests/modeling_gpt2_test.py +++ b/tests/modeling_gpt2_test.py @@ -152,9 +152,10 @@ class GPT2ModelTest(unittest.TestCase): self.parent.assertListEqual( list(result["lm_logits"].size()), [self.batch_size, self.n_choices, self.seq_length, total_voc]) + self.parent.assertEqual(self.n_layer, len(result["presents"])) self.parent.assertListEqual( - list(result["presents"].size()), - [self.batch_size, self.n_choices, self.seq_length, total_voc]) + list(result["presents"][0].size()), + [2, self.batch_size * self.n_choices, self.n_head, self.seq_length, self.n_embd // self.n_head]) def check_gpt2_lm_head_loss_output(self, result): self.parent.assertListEqual(