WIP modeling_test_pytorch.py
This commit is contained in:
@@ -94,11 +94,10 @@ class BertModelTest(unittest.TestCase):
|
|||||||
|
|
||||||
model = modeling.BertModel(config=config)
|
model = modeling.BertModel(config=config)
|
||||||
|
|
||||||
all_encoder_layers, pooled_output, embedding_output, sequence_output = model(input_ids, token_type_ids, input_mask)
|
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
|
||||||
|
|
||||||
outputs = {
|
outputs = {
|
||||||
"embedding_output": embedding_output,
|
"sequence_output": all_encoder_layers[-1],
|
||||||
"sequence_output": sequence_output,
|
|
||||||
"pooled_output": pooled_output,
|
"pooled_output": pooled_output,
|
||||||
"all_encoder_layers": all_encoder_layers,
|
"all_encoder_layers": all_encoder_layers,
|
||||||
}
|
}
|
||||||
@@ -106,13 +105,10 @@ class BertModelTest(unittest.TestCase):
|
|||||||
|
|
||||||
def check_output(self, result):
|
def check_output(self, result):
|
||||||
self.parent.assertListEqual(
|
self.parent.assertListEqual(
|
||||||
result["embedding_output"].shape,
|
list(result["sequence_output"].size()),
|
||||||
[self.batch_size, self.seq_length, self.hidden_size])
|
|
||||||
self.parent.assertListEqual(
|
|
||||||
result["sequence_output"].shape,
|
|
||||||
[self.batch_size, self.seq_length, self.hidden_size])
|
[self.batch_size, self.seq_length, self.hidden_size])
|
||||||
|
|
||||||
self.parent.assertListEqual(result["pooled_output"].shape, [self.batch_size, self.hidden_size])
|
self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size])
|
||||||
|
|
||||||
def test_default(self):
|
def test_default(self):
|
||||||
self.run_tester(BertModelTest.BertModelTester(self))
|
self.run_tester(BertModelTest.BertModelTester(self))
|
||||||
@@ -144,6 +140,7 @@ class BertModelTest(unittest.TestCase):
|
|||||||
for _ in range(total_dims):
|
for _ in range(total_dims):
|
||||||
values.append(rng.randint(0, vocab_size - 1))
|
values.append(rng.randint(0, vocab_size - 1))
|
||||||
|
|
||||||
|
# TODO Solve : the returned tensors provoke index out of range errors when passed to the model
|
||||||
return torch.tensor(data=values, dtype=torch.int32)
|
return torch.tensor(data=values, dtype=torch.int32)
|
||||||
|
|
||||||
def assert_all_tensors_reachable(self, sess, outputs):
|
def assert_all_tensors_reachable(self, sess, outputs):
|
||||||
|
|||||||
Reference in New Issue
Block a user