From f588cf40503ff2d3baaf21dd66144157ce4fa9cd Mon Sep 17 00:00:00 2001 From: Suraj Patil Date: Wed, 9 Feb 2022 16:48:08 +0100 Subject: [PATCH] [Flax tests/FlaxBert] make from_pretrained test faster (#15561) --- tests/test_modeling_flax_bert.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/test_modeling_flax_bert.py b/tests/test_modeling_flax_bert.py index 89436f854f..6b2be334c7 100644 --- a/tests/test_modeling_flax_bert.py +++ b/tests/test_modeling_flax_bert.py @@ -141,7 +141,8 @@ class FlaxBertModelTest(FlaxModelTesterMixin, unittest.TestCase): @slow def test_model_from_pretrained(self): - for model_class_name in self.all_model_classes: - model = model_class_name.from_pretrained("bert-base-cased", from_pt=True) - outputs = model(np.ones((1, 1))) - self.assertIsNotNone(outputs) + # Only check this for base model, not necessary for all model classes. + # This will also help speed-up tests. + model = FlaxBertModel.from_pretrained("bert-base-cased") + outputs = model(np.ones((1, 1))) + self.assertIsNotNone(outputs)