Added integration tests for TensorFlow implementation of the ALBERT model (#9976)
* TF Albert integration test * TF Alber integration test added
This commit is contained in:
@@ -303,3 +303,26 @@ class TFAlbertModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
for model_name in TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||
model = TFAlbertModel.from_pretrained(model_name)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
|
||||
@require_tf
|
||||
class TFAlbertModelIntegrationTest(unittest.TestCase):
|
||||
@slow
|
||||
def test_inference_masked_lm(self):
|
||||
model = TFAlbertForPreTraining.from_pretrained("albert-base-v2")
|
||||
input_ids = tf.constant([[0, 1, 2, 3, 4, 5]])
|
||||
output = model(input_ids)[0]
|
||||
|
||||
expected_shape = [1, 6, 30000]
|
||||
self.assertEqual(output.shape, expected_shape)
|
||||
|
||||
expected_slice = tf.constant(
|
||||
[
|
||||
[
|
||||
[4.595668, 0.74462754, -1.818147],
|
||||
[4.5954347, 0.7454184, -1.8188258],
|
||||
[4.5954905, 0.7448235, -1.8182316],
|
||||
]
|
||||
]
|
||||
)
|
||||
tf.debugging.assert_near(output[:, :3, :3], expected_slice, atol=1e-4)
|
||||
|
||||
Reference in New Issue
Block a user