diff --git a/tests/test_modeling_tf_mobilebert.py b/tests/test_modeling_tf_mobilebert.py index 35c617de3b..d98e2fd27e 100644 --- a/tests/test_modeling_tf_mobilebert.py +++ b/tests/test_modeling_tf_mobilebert.py @@ -319,3 +319,26 @@ class TFMobileBertModelTest(TFModelTesterMixin, unittest.TestCase): for model_name in ["google/mobilebert-uncased"]: model = TFMobileBertModel.from_pretrained(model_name) self.assertIsNotNone(model) + + +@require_tf +class TFMobileBertModelIntegrationTest(unittest.TestCase): + @slow + def test_inference_masked_lm(self): + model = TFMobileBertForPreTraining.from_pretrained("google/mobilebert-uncased") + input_ids = tf.constant([[0, 1, 2, 3, 4, 5]]) + output = model(input_ids)[0] + + expected_shape = [1, 6, 30522] + self.assertEqual(output.shape, expected_shape) + + expected_slice = tf.constant( + [ + [ + [-4.5919547, -9.248295, -9.645256], + [-6.7306175, -6.440284, -6.6052837], + [-7.2743506, -6.7847915, -6.024673], + ] + ] + ) + tf.debugging.assert_near(output[:, :3, :3], expected_slice, atol=1e-4)