From 75fd00fb25d7d194a02c884ff571985e30f6eadd Mon Sep 17 00:00:00 2001 From: sandip Date: Wed, 3 Feb 2021 22:09:40 +0530 Subject: [PATCH] Integration test added for TF MPnet (#9979) --- tests/test_modeling_tf_mpnet.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/tests/test_modeling_tf_mpnet.py b/tests/test_modeling_tf_mpnet.py index da14679ba6..160283350b 100644 --- a/tests/test_modeling_tf_mpnet.py +++ b/tests/test_modeling_tf_mpnet.py @@ -240,3 +240,26 @@ class TFMPNetModelTest(TFModelTesterMixin, unittest.TestCase): for model_name in ["microsoft/mpnet-base"]: model = TFMPNetModel.from_pretrained(model_name) self.assertIsNotNone(model) + + +@require_tf +class TFMPNetModelIntegrationTest(unittest.TestCase): + @slow + def test_inference_masked_lm(self): + model = TFMPNetModel.from_pretrained("microsoft/mpnet-base") + input_ids = tf.constant([[0, 1, 2, 3, 4, 5]]) + output = model(input_ids)[0] + + expected_shape = [1, 6, 768] + self.assertEqual(output.shape, expected_shape) + + expected_slice = tf.constant( + [ + [ + [-0.1067172, 0.08216473, 0.0024543], + [-0.03465879, 0.8354118, -0.03252288], + [-0.06569476, -0.12424111, -0.0494436], + ] + ] + ) + tf.debugging.assert_near(output[:, :3, :3], expected_slice, atol=1e-4)