Fix test_save_load for TFViTMAEModelTest (#19040)

* Fix test_save_load for TFViTMAEModelTest

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar
2022-09-15 15:21:57 +02:00
committed by GitHub
parent 30a28f5227
commit 0a42b61ede

View File

@@ -375,7 +375,6 @@ class TFViTMAEModelTest(TFModelTesterMixin, unittest.TestCase):
# overwrite from common since TFViTMAEForPretraining has random masking, we need to fix the noise # overwrite from common since TFViTMAEForPretraining has random masking, we need to fix the noise
# to generate masks during test # to generate masks during test
@slow
def test_save_load(self): def test_save_load(self):
# make mask reproducible # make mask reproducible
np.random.seed(2) np.random.seed(2)
@@ -398,9 +397,8 @@ class TFViTMAEModelTest(TFModelTesterMixin, unittest.TestCase):
out_2[np.isnan(out_2)] = 0 out_2[np.isnan(out_2)] = 0
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname, saved_model=True) model.save_pretrained(tmpdirname, saved_model=False)
saved_model_dir = os.path.join(tmpdirname, "saved_model", "1") model = model_class.from_pretrained(tmpdirname)
model = tf.keras.models.load_model(saved_model_dir)
after_outputs = model(model_input, noise=noise) after_outputs = model(model_input, noise=noise)
if model_class.__name__ == "TFViTMAEModel": if model_class.__name__ == "TFViTMAEModel":