Fix test_save_load for TFViTMAEModelTest (#19040)
* Fix test_save_load for TFViTMAEModelTest Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -375,7 +375,6 @@ class TFViTMAEModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
|
||||
# overwrite from common since TFViTMAEForPretraining has random masking, we need to fix the noise
|
||||
# to generate masks during test
|
||||
@slow
|
||||
def test_save_load(self):
|
||||
# make mask reproducible
|
||||
np.random.seed(2)
|
||||
@@ -398,9 +397,8 @@ class TFViTMAEModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
out_2[np.isnan(out_2)] = 0
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname, saved_model=True)
|
||||
saved_model_dir = os.path.join(tmpdirname, "saved_model", "1")
|
||||
model = tf.keras.models.load_model(saved_model_dir)
|
||||
model.save_pretrained(tmpdirname, saved_model=False)
|
||||
model = model_class.from_pretrained(tmpdirname)
|
||||
after_outputs = model(model_input, noise=noise)
|
||||
|
||||
if model_class.__name__ == "TFViTMAEModel":
|
||||
|
||||
Reference in New Issue
Block a user