From 77db28dc520b74a4aaac44a8f6d6c5cfe9145a5d Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Thu, 29 Jun 2023 16:05:24 +0200 Subject: [PATCH] Update some torchscript tests after #24505 (#24566) * fix * fix --------- Co-authored-by: ydshieh --- tests/models/align/test_modeling_align.py | 11 ++++ tests/models/altclip/test_modeling_altclip.py | 11 ++++ tests/models/blip/test_modeling_blip.py | 60 +++++++++++++++++++ .../test_modeling_chinese_clip.py | 11 ++++ tests/models/clap/test_modeling_clap.py | 20 +++++++ tests/models/clip/test_modeling_clip.py | 20 +++++++ tests/models/clipseg/test_modeling_clipseg.py | 20 +++++++ tests/models/encodec/test_modeling_encodec.py | 11 ++++ tests/models/flava/test_modeling_flava.py | 20 +++++++ .../graphormer/test_modeling_graphormer.py | 11 ++++ .../models/groupvit/test_modeling_groupvit.py | 20 +++++++ .../models/imagegpt/test_modeling_imagegpt.py | 11 ++++ tests/models/owlvit/test_modeling_owlvit.py | 40 +++++++++++++ .../pix2struct/test_modeling_pix2struct.py | 20 +++++++ .../test_modeling_speech_to_text.py | 20 +++++++ tests/models/whisper/test_modeling_whisper.py | 20 +++++++ tests/models/x_clip/test_modeling_x_clip.py | 20 +++++++ 17 files changed, 346 insertions(+) diff --git a/tests/models/align/test_modeling_align.py b/tests/models/align/test_modeling_align.py index 2357c20e21..35b2aebf6c 100644 --- a/tests/models/align/test_modeling_align.py +++ b/tests/models/align/test_modeling_align.py @@ -538,6 +538,17 @@ class AlignModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys())) + model_buffers = list(model.buffers()) + for non_persistent_buffer in non_persistent_buffers.values(): + found_buffer = False + for i, model_buffer in enumerate(model_buffers): + if torch.equal(non_persistent_buffer, model_buffer): + found_buffer = True + break + + self.assertTrue(found_buffer) + model_buffers.pop(i) + models_equal = True for layer_name, p1 in model_state_dict.items(): p2 = loaded_model_state_dict[layer_name] diff --git a/tests/models/altclip/test_modeling_altclip.py b/tests/models/altclip/test_modeling_altclip.py index 28213de84d..fdae7768da 100755 --- a/tests/models/altclip/test_modeling_altclip.py +++ b/tests/models/altclip/test_modeling_altclip.py @@ -509,6 +509,17 @@ class AltCLIPModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase) self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys())) + model_buffers = list(model.buffers()) + for non_persistent_buffer in non_persistent_buffers.values(): + found_buffer = False + for i, model_buffer in enumerate(model_buffers): + if torch.equal(non_persistent_buffer, model_buffer): + found_buffer = True + break + + self.assertTrue(found_buffer) + model_buffers.pop(i) + models_equal = True for layer_name, p1 in model_state_dict.items(): p2 = loaded_model_state_dict[layer_name] diff --git a/tests/models/blip/test_modeling_blip.py b/tests/models/blip/test_modeling_blip.py index 7d9c6b5ba5..9a6e3da06c 100644 --- a/tests/models/blip/test_modeling_blip.py +++ b/tests/models/blip/test_modeling_blip.py @@ -496,8 +496,28 @@ class BlipModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): model_state_dict = model.state_dict() loaded_model_state_dict = loaded_model.state_dict() + non_persistent_buffers = {} + for key in loaded_model_state_dict.keys(): + if key not in model_state_dict.keys(): + non_persistent_buffers[key] = loaded_model_state_dict[key] + + loaded_model_state_dict = { + key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers + } + self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys())) + model_buffers = list(model.buffers()) + for non_persistent_buffer in non_persistent_buffers.values(): + found_buffer = False + for i, model_buffer in enumerate(model_buffers): + if torch.equal(non_persistent_buffer, model_buffer): + found_buffer = True + break + + self.assertTrue(found_buffer) + model_buffers.pop(i) + models_equal = True for layer_name, p1 in model_state_dict.items(): p2 = loaded_model_state_dict[layer_name] @@ -920,8 +940,28 @@ class BlipTextRetrievalModelTest(ModelTesterMixin, unittest.TestCase): model_state_dict = model.state_dict() loaded_model_state_dict = loaded_model.state_dict() + non_persistent_buffers = {} + for key in loaded_model_state_dict.keys(): + if key not in model_state_dict.keys(): + non_persistent_buffers[key] = loaded_model_state_dict[key] + + loaded_model_state_dict = { + key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers + } + self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys())) + model_buffers = list(model.buffers()) + for non_persistent_buffer in non_persistent_buffers.values(): + found_buffer = False + for i, model_buffer in enumerate(model_buffers): + if torch.equal(non_persistent_buffer, model_buffer): + found_buffer = True + break + + self.assertTrue(found_buffer) + model_buffers.pop(i) + models_equal = True for layer_name, p1 in model_state_dict.items(): p2 = loaded_model_state_dict[layer_name] @@ -1116,8 +1156,28 @@ class BlipTextImageModelTest(ModelTesterMixin, unittest.TestCase): model_state_dict = model.state_dict() loaded_model_state_dict = loaded_model.state_dict() + non_persistent_buffers = {} + for key in loaded_model_state_dict.keys(): + if key not in model_state_dict.keys(): + non_persistent_buffers[key] = loaded_model_state_dict[key] + + loaded_model_state_dict = { + key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers + } + self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys())) + model_buffers = list(model.buffers()) + for non_persistent_buffer in non_persistent_buffers.values(): + found_buffer = False + for i, model_buffer in enumerate(model_buffers): + if torch.equal(non_persistent_buffer, model_buffer): + found_buffer = True + break + + self.assertTrue(found_buffer) + model_buffers.pop(i) + models_equal = True for layer_name, p1 in model_state_dict.items(): p2 = loaded_model_state_dict[layer_name] diff --git a/tests/models/chinese_clip/test_modeling_chinese_clip.py b/tests/models/chinese_clip/test_modeling_chinese_clip.py index 57f532da86..894f1e1279 100644 --- a/tests/models/chinese_clip/test_modeling_chinese_clip.py +++ b/tests/models/chinese_clip/test_modeling_chinese_clip.py @@ -647,6 +647,17 @@ class ChineseCLIPModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestC self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys())) + model_buffers = list(model.buffers()) + for non_persistent_buffer in non_persistent_buffers.values(): + found_buffer = False + for i, model_buffer in enumerate(model_buffers): + if torch.equal(non_persistent_buffer, model_buffer): + found_buffer = True + break + + self.assertTrue(found_buffer) + model_buffers.pop(i) + models_equal = True for layer_name, p1 in model_state_dict.items(): p2 = loaded_model_state_dict[layer_name] diff --git a/tests/models/clap/test_modeling_clap.py b/tests/models/clap/test_modeling_clap.py index 56a7dfdcfc..53b4082c68 100644 --- a/tests/models/clap/test_modeling_clap.py +++ b/tests/models/clap/test_modeling_clap.py @@ -575,8 +575,28 @@ class ClapModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): model_state_dict = model.state_dict() loaded_model_state_dict = loaded_model.state_dict() + non_persistent_buffers = {} + for key in loaded_model_state_dict.keys(): + if key not in model_state_dict.keys(): + non_persistent_buffers[key] = loaded_model_state_dict[key] + + loaded_model_state_dict = { + key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers + } + self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys())) + model_buffers = list(model.buffers()) + for non_persistent_buffer in non_persistent_buffers.values(): + found_buffer = False + for i, model_buffer in enumerate(model_buffers): + if torch.equal(non_persistent_buffer, model_buffer): + found_buffer = True + break + + self.assertTrue(found_buffer) + model_buffers.pop(i) + models_equal = True for layer_name, p1 in model_state_dict.items(): p2 = loaded_model_state_dict[layer_name] diff --git a/tests/models/clip/test_modeling_clip.py b/tests/models/clip/test_modeling_clip.py index 98a3f83838..4239ce5ed0 100644 --- a/tests/models/clip/test_modeling_clip.py +++ b/tests/models/clip/test_modeling_clip.py @@ -547,8 +547,28 @@ class CLIPModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): model_state_dict = model.state_dict() loaded_model_state_dict = loaded_model.state_dict() + non_persistent_buffers = {} + for key in loaded_model_state_dict.keys(): + if key not in model_state_dict.keys(): + non_persistent_buffers[key] = loaded_model_state_dict[key] + + loaded_model_state_dict = { + key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers + } + self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys())) + model_buffers = list(model.buffers()) + for non_persistent_buffer in non_persistent_buffers.values(): + found_buffer = False + for i, model_buffer in enumerate(model_buffers): + if torch.equal(non_persistent_buffer, model_buffer): + found_buffer = True + break + + self.assertTrue(found_buffer) + model_buffers.pop(i) + models_equal = True for layer_name, p1 in model_state_dict.items(): p2 = loaded_model_state_dict[layer_name] diff --git a/tests/models/clipseg/test_modeling_clipseg.py b/tests/models/clipseg/test_modeling_clipseg.py index 35dca117e2..e931bdc8d5 100644 --- a/tests/models/clipseg/test_modeling_clipseg.py +++ b/tests/models/clipseg/test_modeling_clipseg.py @@ -528,8 +528,28 @@ class CLIPSegModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase) model_state_dict = model.state_dict() loaded_model_state_dict = loaded_model.state_dict() + non_persistent_buffers = {} + for key in loaded_model_state_dict.keys(): + if key not in model_state_dict.keys(): + non_persistent_buffers[key] = loaded_model_state_dict[key] + + loaded_model_state_dict = { + key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers + } + self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys())) + model_buffers = list(model.buffers()) + for non_persistent_buffer in non_persistent_buffers.values(): + found_buffer = False + for i, model_buffer in enumerate(model_buffers): + if torch.equal(non_persistent_buffer, model_buffer): + found_buffer = True + break + + self.assertTrue(found_buffer) + model_buffers.pop(i) + models_equal = True for layer_name, p1 in model_state_dict.items(): p2 = loaded_model_state_dict[layer_name] diff --git a/tests/models/encodec/test_modeling_encodec.py b/tests/models/encodec/test_modeling_encodec.py index b6291aafed..d678b707a5 100644 --- a/tests/models/encodec/test_modeling_encodec.py +++ b/tests/models/encodec/test_modeling_encodec.py @@ -229,6 +229,17 @@ class EncodecModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase) self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys())) + model_buffers = list(model.buffers()) + for non_persistent_buffer in non_persistent_buffers.values(): + found_buffer = False + for i, model_buffer in enumerate(model_buffers): + if torch.equal(non_persistent_buffer, model_buffer): + found_buffer = True + break + + self.assertTrue(found_buffer) + model_buffers.pop(i) + model_buffers = list(model.buffers()) for non_persistent_buffer in non_persistent_buffers.values(): found_buffer = False diff --git a/tests/models/flava/test_modeling_flava.py b/tests/models/flava/test_modeling_flava.py index 2544b7ee93..e69d2226c8 100644 --- a/tests/models/flava/test_modeling_flava.py +++ b/tests/models/flava/test_modeling_flava.py @@ -964,8 +964,28 @@ class FlavaModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): # Non persistent buffers won't be in original state dict loaded_model_state_dict.pop("text_model.embeddings.token_type_ids", None) + non_persistent_buffers = {} + for key in loaded_model_state_dict.keys(): + if key not in model_state_dict.keys(): + non_persistent_buffers[key] = loaded_model_state_dict[key] + + loaded_model_state_dict = { + key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers + } + self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys())) + model_buffers = list(model.buffers()) + for non_persistent_buffer in non_persistent_buffers.values(): + found_buffer = False + for i, model_buffer in enumerate(model_buffers): + if torch.equal(non_persistent_buffer, model_buffer): + found_buffer = True + break + + self.assertTrue(found_buffer) + model_buffers.pop(i) + models_equal = True for layer_name, p1 in model_state_dict.items(): p2 = loaded_model_state_dict[layer_name] diff --git a/tests/models/graphormer/test_modeling_graphormer.py b/tests/models/graphormer/test_modeling_graphormer.py index e874ebf0f4..bc860e1a8a 100644 --- a/tests/models/graphormer/test_modeling_graphormer.py +++ b/tests/models/graphormer/test_modeling_graphormer.py @@ -321,6 +321,17 @@ class GraphormerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys())) + model_buffers = list(model.buffers()) + for non_persistent_buffer in non_persistent_buffers.values(): + found_buffer = False + for i, model_buffer in enumerate(model_buffers): + if torch.equal(non_persistent_buffer, model_buffer): + found_buffer = True + break + + self.assertTrue(found_buffer) + model_buffers.pop(i) + model_buffers = list(model.buffers()) for non_persistent_buffer in non_persistent_buffers.values(): found_buffer = False diff --git a/tests/models/groupvit/test_modeling_groupvit.py b/tests/models/groupvit/test_modeling_groupvit.py index f87cec426b..261841277a 100644 --- a/tests/models/groupvit/test_modeling_groupvit.py +++ b/tests/models/groupvit/test_modeling_groupvit.py @@ -630,8 +630,28 @@ class GroupViTModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase model_state_dict = model.state_dict() loaded_model_state_dict = loaded_model.state_dict() + non_persistent_buffers = {} + for key in loaded_model_state_dict.keys(): + if key not in model_state_dict.keys(): + non_persistent_buffers[key] = loaded_model_state_dict[key] + + loaded_model_state_dict = { + key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers + } + self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys())) + model_buffers = list(model.buffers()) + for non_persistent_buffer in non_persistent_buffers.values(): + found_buffer = False + for i, model_buffer in enumerate(model_buffers): + if torch.equal(non_persistent_buffer, model_buffer): + found_buffer = True + break + + self.assertTrue(found_buffer) + model_buffers.pop(i) + models_equal = True for layer_name, p1 in model_state_dict.items(): p2 = loaded_model_state_dict[layer_name] diff --git a/tests/models/imagegpt/test_modeling_imagegpt.py b/tests/models/imagegpt/test_modeling_imagegpt.py index af08725ef1..19fe688bf4 100644 --- a/tests/models/imagegpt/test_modeling_imagegpt.py +++ b/tests/models/imagegpt/test_modeling_imagegpt.py @@ -500,6 +500,17 @@ class ImageGPTModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys())) + model_buffers = list(model.buffers()) + for non_persistent_buffer in non_persistent_buffers.values(): + found_buffer = False + for i, model_buffer in enumerate(model_buffers): + if torch.equal(non_persistent_buffer, model_buffer): + found_buffer = True + break + + self.assertTrue(found_buffer) + model_buffers.pop(i) + model_buffers = list(model.buffers()) for non_persistent_buffer in non_persistent_buffers.values(): found_buffer = False diff --git a/tests/models/owlvit/test_modeling_owlvit.py b/tests/models/owlvit/test_modeling_owlvit.py index acf078ffe8..4dbd1fb0a8 100644 --- a/tests/models/owlvit/test_modeling_owlvit.py +++ b/tests/models/owlvit/test_modeling_owlvit.py @@ -491,8 +491,28 @@ class OwlViTModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): model_state_dict = model.state_dict() loaded_model_state_dict = loaded_model.state_dict() + non_persistent_buffers = {} + for key in loaded_model_state_dict.keys(): + if key not in model_state_dict.keys(): + non_persistent_buffers[key] = loaded_model_state_dict[key] + + loaded_model_state_dict = { + key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers + } + self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys())) + model_buffers = list(model.buffers()) + for non_persistent_buffer in non_persistent_buffers.values(): + found_buffer = False + for i, model_buffer in enumerate(model_buffers): + if torch.equal(non_persistent_buffer, model_buffer): + found_buffer = True + break + + self.assertTrue(found_buffer) + model_buffers.pop(i) + models_equal = True for layer_name, p1 in model_state_dict.items(): p2 = loaded_model_state_dict[layer_name] @@ -670,8 +690,28 @@ class OwlViTForObjectDetectionTest(ModelTesterMixin, unittest.TestCase): model_state_dict = model.state_dict() loaded_model_state_dict = loaded_model.state_dict() + non_persistent_buffers = {} + for key in loaded_model_state_dict.keys(): + if key not in model_state_dict.keys(): + non_persistent_buffers[key] = loaded_model_state_dict[key] + + loaded_model_state_dict = { + key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers + } + self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys())) + model_buffers = list(model.buffers()) + for non_persistent_buffer in non_persistent_buffers.values(): + found_buffer = False + for i, model_buffer in enumerate(model_buffers): + if torch.equal(non_persistent_buffer, model_buffer): + found_buffer = True + break + + self.assertTrue(found_buffer) + model_buffers.pop(i) + models_equal = True for layer_name, p1 in model_state_dict.items(): p2 = loaded_model_state_dict[layer_name] diff --git a/tests/models/pix2struct/test_modeling_pix2struct.py b/tests/models/pix2struct/test_modeling_pix2struct.py index 8ec023676d..c49db3dcff 100644 --- a/tests/models/pix2struct/test_modeling_pix2struct.py +++ b/tests/models/pix2struct/test_modeling_pix2struct.py @@ -669,8 +669,28 @@ class Pix2StructModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa model_state_dict = model.state_dict() loaded_model_state_dict = loaded_model.state_dict() + non_persistent_buffers = {} + for key in loaded_model_state_dict.keys(): + if key not in model_state_dict.keys(): + non_persistent_buffers[key] = loaded_model_state_dict[key] + + loaded_model_state_dict = { + key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers + } + self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys())) + model_buffers = list(model.buffers()) + for non_persistent_buffer in non_persistent_buffers.values(): + found_buffer = False + for i, model_buffer in enumerate(model_buffers): + if torch.equal(non_persistent_buffer, model_buffer): + found_buffer = True + break + + self.assertTrue(found_buffer) + model_buffers.pop(i) + models_equal = True for layer_name, p1 in model_state_dict.items(): p2 = loaded_model_state_dict[layer_name] diff --git a/tests/models/speech_to_text/test_modeling_speech_to_text.py b/tests/models/speech_to_text/test_modeling_speech_to_text.py index 16ad704fd5..d86fc43a82 100644 --- a/tests/models/speech_to_text/test_modeling_speech_to_text.py +++ b/tests/models/speech_to_text/test_modeling_speech_to_text.py @@ -712,8 +712,28 @@ class Speech2TextModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTest model_state_dict = model.state_dict() loaded_model_state_dict = loaded_model.state_dict() + non_persistent_buffers = {} + for key in loaded_model_state_dict.keys(): + if key not in model_state_dict.keys(): + non_persistent_buffers[key] = loaded_model_state_dict[key] + + loaded_model_state_dict = { + key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers + } + self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys())) + model_buffers = list(model.buffers()) + for non_persistent_buffer in non_persistent_buffers.values(): + found_buffer = False + for i, model_buffer in enumerate(model_buffers): + if torch.equal(non_persistent_buffer, model_buffer): + found_buffer = True + break + + self.assertTrue(found_buffer) + model_buffers.pop(i) + models_equal = True for layer_name, p1 in model_state_dict.items(): p2 = loaded_model_state_dict[layer_name] diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index 614b29b160..c504c1005d 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -814,8 +814,28 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi model_state_dict = model.state_dict() loaded_model_state_dict = loaded_model.state_dict() + non_persistent_buffers = {} + for key in loaded_model_state_dict.keys(): + if key not in model_state_dict.keys(): + non_persistent_buffers[key] = loaded_model_state_dict[key] + + loaded_model_state_dict = { + key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers + } + self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys())) + model_buffers = list(model.buffers()) + for non_persistent_buffer in non_persistent_buffers.values(): + found_buffer = False + for i, model_buffer in enumerate(model_buffers): + if torch.equal(non_persistent_buffer, model_buffer): + found_buffer = True + break + + self.assertTrue(found_buffer) + model_buffers.pop(i) + models_equal = True for layer_name, p1 in model_state_dict.items(): p2 = loaded_model_state_dict[layer_name] diff --git a/tests/models/x_clip/test_modeling_x_clip.py b/tests/models/x_clip/test_modeling_x_clip.py index 2efece44ca..ef2d11ac62 100644 --- a/tests/models/x_clip/test_modeling_x_clip.py +++ b/tests/models/x_clip/test_modeling_x_clip.py @@ -612,8 +612,28 @@ class XCLIPModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): model_state_dict = model.state_dict() loaded_model_state_dict = loaded_model.state_dict() + non_persistent_buffers = {} + for key in loaded_model_state_dict.keys(): + if key not in model_state_dict.keys(): + non_persistent_buffers[key] = loaded_model_state_dict[key] + + loaded_model_state_dict = { + key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers + } + self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys())) + model_buffers = list(model.buffers()) + for non_persistent_buffer in non_persistent_buffers.values(): + found_buffer = False + for i, model_buffer in enumerate(model_buffers): + if torch.equal(non_persistent_buffer, model_buffer): + found_buffer = True + break + + self.assertTrue(found_buffer) + model_buffers.pop(i) + models_equal = True for layer_name, p1 in model_state_dict.items(): p2 = loaded_model_state_dict[layer_name]