[Flax tests/FlaxBert] make from_pretrained test faster (#15561)

This commit is contained in:
Suraj Patil
2022-02-09 16:48:08 +01:00
committed by GitHub
parent 7029240927
commit f588cf4050

View File

@@ -141,7 +141,8 @@ class FlaxBertModelTest(FlaxModelTesterMixin, unittest.TestCase):
@slow
def test_model_from_pretrained(self):
for model_class_name in self.all_model_classes:
model = model_class_name.from_pretrained("bert-base-cased", from_pt=True)
# Only check this for base model, not necessary for all model classes.
# This will also help speed-up tests.
model = FlaxBertModel.from_pretrained("bert-base-cased")
outputs = model(np.ones((1, 1)))
self.assertIsNotNone(outputs)