From 0a42b61edec47acb8dabb64e5f0e9e97b0746a42 Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Thu, 15 Sep 2022 15:21:57 +0200 Subject: [PATCH] Fix `test_save_load` for `TFViTMAEModelTest` (#19040) * Fix test_save_load for TFViTMAEModelTest Co-authored-by: ydshieh --- tests/models/vit_mae/test_modeling_tf_vit_mae.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/models/vit_mae/test_modeling_tf_vit_mae.py b/tests/models/vit_mae/test_modeling_tf_vit_mae.py index e9db7ea6b2..f05ecaf69c 100644 --- a/tests/models/vit_mae/test_modeling_tf_vit_mae.py +++ b/tests/models/vit_mae/test_modeling_tf_vit_mae.py @@ -375,7 +375,6 @@ class TFViTMAEModelTest(TFModelTesterMixin, unittest.TestCase): # overwrite from common since TFViTMAEForPretraining has random masking, we need to fix the noise # to generate masks during test - @slow def test_save_load(self): # make mask reproducible np.random.seed(2) @@ -398,9 +397,8 @@ class TFViTMAEModelTest(TFModelTesterMixin, unittest.TestCase): out_2[np.isnan(out_2)] = 0 with tempfile.TemporaryDirectory() as tmpdirname: - model.save_pretrained(tmpdirname, saved_model=True) - saved_model_dir = os.path.join(tmpdirname, "saved_model", "1") - model = tf.keras.models.load_model(saved_model_dir) + model.save_pretrained(tmpdirname, saved_model=False) + model = model_class.from_pretrained(tmpdirname) after_outputs = model(model_input, noise=noise) if model_class.__name__ == "TFViTMAEModel":