From 1486205d230dd83f6b3f690b9d1cc14483adac09 Mon Sep 17 00:00:00 2001 From: sandip Date: Wed, 3 Feb 2021 20:21:00 +0530 Subject: [PATCH] TF DistilBERT integration tests (#9975) * TF DistilBERT integration test * Update test_modeling_tf_distilbert.py --- tests/test_modeling_tf_distilbert.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/tests/test_modeling_tf_distilbert.py b/tests/test_modeling_tf_distilbert.py index 3c1b755ccc..a10683f9a0 100644 --- a/tests/test_modeling_tf_distilbert.py +++ b/tests/test_modeling_tf_distilbert.py @@ -221,3 +221,26 @@ class TFDistilBertModelTest(TFModelTesterMixin, unittest.TestCase): for model_name in list(TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]): model = TFDistilBertModel.from_pretrained(model_name) self.assertIsNotNone(model) + + +@require_tf +class TFDistilBertModelIntegrationTest(unittest.TestCase): + @slow + def test_inference_masked_lm(self): + model = TFDistilBertModel.from_pretrained("distilbert-base-uncased") + input_ids = tf.constant([[0, 1, 2, 3, 4, 5]]) + output = model(input_ids)[0] + + expected_shape = [1, 6, 768] + self.assertEqual(output.shape, expected_shape) + + expected_slice = tf.constant( + [ + [ + [0.19261885, -0.13732955, 0.4119799], + [0.22150156, -0.07422661, 0.39037204], + [0.22756018, -0.0896414, 0.3701467], + ] + ] + ) + tf.debugging.assert_near(output[:, :3, :3], expected_slice, atol=1e-4)