From 5de1517d6bf6108c7fc578d6ee544f7ce2754463 Mon Sep 17 00:00:00 2001 From: Tim Rault Date: Sat, 3 Nov 2018 22:40:50 +0100 Subject: [PATCH] WIP modeling_test_pytorch.py --- modeling_test_pytorch.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/modeling_test_pytorch.py b/modeling_test_pytorch.py index f028901276..d98a6993b5 100644 --- a/modeling_test_pytorch.py +++ b/modeling_test_pytorch.py @@ -94,11 +94,10 @@ class BertModelTest(unittest.TestCase): model = modeling.BertModel(config=config) - all_encoder_layers, pooled_output, embedding_output, sequence_output = model(input_ids, token_type_ids, input_mask) + all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask) outputs = { - "embedding_output": embedding_output, - "sequence_output": sequence_output, + "sequence_output": all_encoder_layers[-1], "pooled_output": pooled_output, "all_encoder_layers": all_encoder_layers, } @@ -106,13 +105,10 @@ class BertModelTest(unittest.TestCase): def check_output(self, result): self.parent.assertListEqual( - result["embedding_output"].shape, - [self.batch_size, self.seq_length, self.hidden_size]) - self.parent.assertListEqual( - result["sequence_output"].shape, + list(result["sequence_output"].size()), [self.batch_size, self.seq_length, self.hidden_size]) - self.parent.assertListEqual(result["pooled_output"].shape, [self.batch_size, self.hidden_size]) + self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size]) def test_default(self): self.run_tester(BertModelTest.BertModelTester(self)) @@ -144,6 +140,7 @@ class BertModelTest(unittest.TestCase): for _ in range(total_dims): values.append(rng.randint(0, vocab_size - 1)) + # TODO Solve : the returned tensors provoke index out of range errors when passed to the model return torch.tensor(data=values, dtype=torch.int32) def assert_all_tensors_reachable(self, sess, outputs):