Fix test_save_load for TFViTMAEModelTest (#19040)
* Fix test_save_load for TFViTMAEModelTest Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -375,7 +375,6 @@ class TFViTMAEModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
|
|
||||||
# overwrite from common since TFViTMAEForPretraining has random masking, we need to fix the noise
|
# overwrite from common since TFViTMAEForPretraining has random masking, we need to fix the noise
|
||||||
# to generate masks during test
|
# to generate masks during test
|
||||||
@slow
|
|
||||||
def test_save_load(self):
|
def test_save_load(self):
|
||||||
# make mask reproducible
|
# make mask reproducible
|
||||||
np.random.seed(2)
|
np.random.seed(2)
|
||||||
@@ -398,9 +397,8 @@ class TFViTMAEModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
out_2[np.isnan(out_2)] = 0
|
out_2[np.isnan(out_2)] = 0
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
model.save_pretrained(tmpdirname, saved_model=True)
|
model.save_pretrained(tmpdirname, saved_model=False)
|
||||||
saved_model_dir = os.path.join(tmpdirname, "saved_model", "1")
|
model = model_class.from_pretrained(tmpdirname)
|
||||||
model = tf.keras.models.load_model(saved_model_dir)
|
|
||||||
after_outputs = model(model_input, noise=noise)
|
after_outputs = model(model_input, noise=noise)
|
||||||
|
|
||||||
if model_class.__name__ == "TFViTMAEModel":
|
if model_class.__name__ == "TFViTMAEModel":
|
||||||
|
|||||||
Reference in New Issue
Block a user