Fix AlignModelTest tests (#21923)
* fix * fix --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -65,7 +65,7 @@ class AlignVisionModelTester:
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
parent,
|
parent,
|
||||||
batch_size=13,
|
batch_size=12,
|
||||||
image_size=32,
|
image_size=32,
|
||||||
num_channels=3,
|
num_channels=3,
|
||||||
kernel_sizes=[3, 3, 5],
|
kernel_sizes=[3, 3, 5],
|
||||||
@@ -234,7 +234,7 @@ class AlignTextModelTester:
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
parent,
|
parent,
|
||||||
batch_size=13,
|
batch_size=12,
|
||||||
seq_length=7,
|
seq_length=7,
|
||||||
is_training=True,
|
is_training=True,
|
||||||
use_input_mask=True,
|
use_input_mask=True,
|
||||||
@@ -521,6 +521,15 @@ class AlignModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
model_state_dict = model.state_dict()
|
model_state_dict = model.state_dict()
|
||||||
loaded_model_state_dict = loaded_model.state_dict()
|
loaded_model_state_dict = loaded_model.state_dict()
|
||||||
|
|
||||||
|
non_persistent_buffers = {}
|
||||||
|
for key in loaded_model_state_dict.keys():
|
||||||
|
if key not in model_state_dict.keys():
|
||||||
|
non_persistent_buffers[key] = loaded_model_state_dict[key]
|
||||||
|
|
||||||
|
loaded_model_state_dict = {
|
||||||
|
key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers
|
||||||
|
}
|
||||||
|
|
||||||
self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))
|
self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))
|
||||||
|
|
||||||
models_equal = True
|
models_equal = True
|
||||||
|
|||||||
Reference in New Issue
Block a user