From d4306daea1f68d8e854b7b3b127878a5fbd53489 Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Fri, 3 Mar 2023 14:47:09 +0100 Subject: [PATCH] Fix `AlignModelTest` tests (#21923) * fix * fix --------- Co-authored-by: ydshieh --- tests/models/align/test_modeling_align.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/tests/models/align/test_modeling_align.py b/tests/models/align/test_modeling_align.py index 5376f8d08d..f2b1b1efda 100644 --- a/tests/models/align/test_modeling_align.py +++ b/tests/models/align/test_modeling_align.py @@ -65,7 +65,7 @@ class AlignVisionModelTester: def __init__( self, parent, - batch_size=13, + batch_size=12, image_size=32, num_channels=3, kernel_sizes=[3, 3, 5], @@ -234,7 +234,7 @@ class AlignTextModelTester: def __init__( self, parent, - batch_size=13, + batch_size=12, seq_length=7, is_training=True, use_input_mask=True, @@ -521,6 +521,15 @@ class AlignModelTest(ModelTesterMixin, unittest.TestCase): model_state_dict = model.state_dict() loaded_model_state_dict = loaded_model.state_dict() + non_persistent_buffers = {} + for key in loaded_model_state_dict.keys(): + if key not in model_state_dict.keys(): + non_persistent_buffers[key] = loaded_model_state_dict[key] + + loaded_model_state_dict = { + key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers + } + self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys())) models_equal = True