* fix * fix --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -538,6 +538,17 @@ class AlignModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
|||||||
|
|
||||||
self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))
|
self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))
|
||||||
|
|
||||||
|
model_buffers = list(model.buffers())
|
||||||
|
for non_persistent_buffer in non_persistent_buffers.values():
|
||||||
|
found_buffer = False
|
||||||
|
for i, model_buffer in enumerate(model_buffers):
|
||||||
|
if torch.equal(non_persistent_buffer, model_buffer):
|
||||||
|
found_buffer = True
|
||||||
|
break
|
||||||
|
|
||||||
|
self.assertTrue(found_buffer)
|
||||||
|
model_buffers.pop(i)
|
||||||
|
|
||||||
models_equal = True
|
models_equal = True
|
||||||
for layer_name, p1 in model_state_dict.items():
|
for layer_name, p1 in model_state_dict.items():
|
||||||
p2 = loaded_model_state_dict[layer_name]
|
p2 = loaded_model_state_dict[layer_name]
|
||||||
|
|||||||
@@ -509,6 +509,17 @@ class AltCLIPModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
|
|||||||
|
|
||||||
self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))
|
self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))
|
||||||
|
|
||||||
|
model_buffers = list(model.buffers())
|
||||||
|
for non_persistent_buffer in non_persistent_buffers.values():
|
||||||
|
found_buffer = False
|
||||||
|
for i, model_buffer in enumerate(model_buffers):
|
||||||
|
if torch.equal(non_persistent_buffer, model_buffer):
|
||||||
|
found_buffer = True
|
||||||
|
break
|
||||||
|
|
||||||
|
self.assertTrue(found_buffer)
|
||||||
|
model_buffers.pop(i)
|
||||||
|
|
||||||
models_equal = True
|
models_equal = True
|
||||||
for layer_name, p1 in model_state_dict.items():
|
for layer_name, p1 in model_state_dict.items():
|
||||||
p2 = loaded_model_state_dict[layer_name]
|
p2 = loaded_model_state_dict[layer_name]
|
||||||
|
|||||||
@@ -496,8 +496,28 @@ class BlipModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
|||||||
model_state_dict = model.state_dict()
|
model_state_dict = model.state_dict()
|
||||||
loaded_model_state_dict = loaded_model.state_dict()
|
loaded_model_state_dict = loaded_model.state_dict()
|
||||||
|
|
||||||
|
non_persistent_buffers = {}
|
||||||
|
for key in loaded_model_state_dict.keys():
|
||||||
|
if key not in model_state_dict.keys():
|
||||||
|
non_persistent_buffers[key] = loaded_model_state_dict[key]
|
||||||
|
|
||||||
|
loaded_model_state_dict = {
|
||||||
|
key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers
|
||||||
|
}
|
||||||
|
|
||||||
self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))
|
self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))
|
||||||
|
|
||||||
|
model_buffers = list(model.buffers())
|
||||||
|
for non_persistent_buffer in non_persistent_buffers.values():
|
||||||
|
found_buffer = False
|
||||||
|
for i, model_buffer in enumerate(model_buffers):
|
||||||
|
if torch.equal(non_persistent_buffer, model_buffer):
|
||||||
|
found_buffer = True
|
||||||
|
break
|
||||||
|
|
||||||
|
self.assertTrue(found_buffer)
|
||||||
|
model_buffers.pop(i)
|
||||||
|
|
||||||
models_equal = True
|
models_equal = True
|
||||||
for layer_name, p1 in model_state_dict.items():
|
for layer_name, p1 in model_state_dict.items():
|
||||||
p2 = loaded_model_state_dict[layer_name]
|
p2 = loaded_model_state_dict[layer_name]
|
||||||
@@ -920,8 +940,28 @@ class BlipTextRetrievalModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
model_state_dict = model.state_dict()
|
model_state_dict = model.state_dict()
|
||||||
loaded_model_state_dict = loaded_model.state_dict()
|
loaded_model_state_dict = loaded_model.state_dict()
|
||||||
|
|
||||||
|
non_persistent_buffers = {}
|
||||||
|
for key in loaded_model_state_dict.keys():
|
||||||
|
if key not in model_state_dict.keys():
|
||||||
|
non_persistent_buffers[key] = loaded_model_state_dict[key]
|
||||||
|
|
||||||
|
loaded_model_state_dict = {
|
||||||
|
key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers
|
||||||
|
}
|
||||||
|
|
||||||
self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))
|
self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))
|
||||||
|
|
||||||
|
model_buffers = list(model.buffers())
|
||||||
|
for non_persistent_buffer in non_persistent_buffers.values():
|
||||||
|
found_buffer = False
|
||||||
|
for i, model_buffer in enumerate(model_buffers):
|
||||||
|
if torch.equal(non_persistent_buffer, model_buffer):
|
||||||
|
found_buffer = True
|
||||||
|
break
|
||||||
|
|
||||||
|
self.assertTrue(found_buffer)
|
||||||
|
model_buffers.pop(i)
|
||||||
|
|
||||||
models_equal = True
|
models_equal = True
|
||||||
for layer_name, p1 in model_state_dict.items():
|
for layer_name, p1 in model_state_dict.items():
|
||||||
p2 = loaded_model_state_dict[layer_name]
|
p2 = loaded_model_state_dict[layer_name]
|
||||||
@@ -1116,8 +1156,28 @@ class BlipTextImageModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
model_state_dict = model.state_dict()
|
model_state_dict = model.state_dict()
|
||||||
loaded_model_state_dict = loaded_model.state_dict()
|
loaded_model_state_dict = loaded_model.state_dict()
|
||||||
|
|
||||||
|
non_persistent_buffers = {}
|
||||||
|
for key in loaded_model_state_dict.keys():
|
||||||
|
if key not in model_state_dict.keys():
|
||||||
|
non_persistent_buffers[key] = loaded_model_state_dict[key]
|
||||||
|
|
||||||
|
loaded_model_state_dict = {
|
||||||
|
key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers
|
||||||
|
}
|
||||||
|
|
||||||
self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))
|
self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))
|
||||||
|
|
||||||
|
model_buffers = list(model.buffers())
|
||||||
|
for non_persistent_buffer in non_persistent_buffers.values():
|
||||||
|
found_buffer = False
|
||||||
|
for i, model_buffer in enumerate(model_buffers):
|
||||||
|
if torch.equal(non_persistent_buffer, model_buffer):
|
||||||
|
found_buffer = True
|
||||||
|
break
|
||||||
|
|
||||||
|
self.assertTrue(found_buffer)
|
||||||
|
model_buffers.pop(i)
|
||||||
|
|
||||||
models_equal = True
|
models_equal = True
|
||||||
for layer_name, p1 in model_state_dict.items():
|
for layer_name, p1 in model_state_dict.items():
|
||||||
p2 = loaded_model_state_dict[layer_name]
|
p2 = loaded_model_state_dict[layer_name]
|
||||||
|
|||||||
@@ -647,6 +647,17 @@ class ChineseCLIPModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestC
|
|||||||
|
|
||||||
self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))
|
self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))
|
||||||
|
|
||||||
|
model_buffers = list(model.buffers())
|
||||||
|
for non_persistent_buffer in non_persistent_buffers.values():
|
||||||
|
found_buffer = False
|
||||||
|
for i, model_buffer in enumerate(model_buffers):
|
||||||
|
if torch.equal(non_persistent_buffer, model_buffer):
|
||||||
|
found_buffer = True
|
||||||
|
break
|
||||||
|
|
||||||
|
self.assertTrue(found_buffer)
|
||||||
|
model_buffers.pop(i)
|
||||||
|
|
||||||
models_equal = True
|
models_equal = True
|
||||||
for layer_name, p1 in model_state_dict.items():
|
for layer_name, p1 in model_state_dict.items():
|
||||||
p2 = loaded_model_state_dict[layer_name]
|
p2 = loaded_model_state_dict[layer_name]
|
||||||
|
|||||||
@@ -575,8 +575,28 @@ class ClapModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
|||||||
model_state_dict = model.state_dict()
|
model_state_dict = model.state_dict()
|
||||||
loaded_model_state_dict = loaded_model.state_dict()
|
loaded_model_state_dict = loaded_model.state_dict()
|
||||||
|
|
||||||
|
non_persistent_buffers = {}
|
||||||
|
for key in loaded_model_state_dict.keys():
|
||||||
|
if key not in model_state_dict.keys():
|
||||||
|
non_persistent_buffers[key] = loaded_model_state_dict[key]
|
||||||
|
|
||||||
|
loaded_model_state_dict = {
|
||||||
|
key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers
|
||||||
|
}
|
||||||
|
|
||||||
self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))
|
self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))
|
||||||
|
|
||||||
|
model_buffers = list(model.buffers())
|
||||||
|
for non_persistent_buffer in non_persistent_buffers.values():
|
||||||
|
found_buffer = False
|
||||||
|
for i, model_buffer in enumerate(model_buffers):
|
||||||
|
if torch.equal(non_persistent_buffer, model_buffer):
|
||||||
|
found_buffer = True
|
||||||
|
break
|
||||||
|
|
||||||
|
self.assertTrue(found_buffer)
|
||||||
|
model_buffers.pop(i)
|
||||||
|
|
||||||
models_equal = True
|
models_equal = True
|
||||||
for layer_name, p1 in model_state_dict.items():
|
for layer_name, p1 in model_state_dict.items():
|
||||||
p2 = loaded_model_state_dict[layer_name]
|
p2 = loaded_model_state_dict[layer_name]
|
||||||
|
|||||||
@@ -547,8 +547,28 @@ class CLIPModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
|||||||
model_state_dict = model.state_dict()
|
model_state_dict = model.state_dict()
|
||||||
loaded_model_state_dict = loaded_model.state_dict()
|
loaded_model_state_dict = loaded_model.state_dict()
|
||||||
|
|
||||||
|
non_persistent_buffers = {}
|
||||||
|
for key in loaded_model_state_dict.keys():
|
||||||
|
if key not in model_state_dict.keys():
|
||||||
|
non_persistent_buffers[key] = loaded_model_state_dict[key]
|
||||||
|
|
||||||
|
loaded_model_state_dict = {
|
||||||
|
key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers
|
||||||
|
}
|
||||||
|
|
||||||
self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))
|
self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))
|
||||||
|
|
||||||
|
model_buffers = list(model.buffers())
|
||||||
|
for non_persistent_buffer in non_persistent_buffers.values():
|
||||||
|
found_buffer = False
|
||||||
|
for i, model_buffer in enumerate(model_buffers):
|
||||||
|
if torch.equal(non_persistent_buffer, model_buffer):
|
||||||
|
found_buffer = True
|
||||||
|
break
|
||||||
|
|
||||||
|
self.assertTrue(found_buffer)
|
||||||
|
model_buffers.pop(i)
|
||||||
|
|
||||||
models_equal = True
|
models_equal = True
|
||||||
for layer_name, p1 in model_state_dict.items():
|
for layer_name, p1 in model_state_dict.items():
|
||||||
p2 = loaded_model_state_dict[layer_name]
|
p2 = loaded_model_state_dict[layer_name]
|
||||||
|
|||||||
@@ -528,8 +528,28 @@ class CLIPSegModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
|
|||||||
model_state_dict = model.state_dict()
|
model_state_dict = model.state_dict()
|
||||||
loaded_model_state_dict = loaded_model.state_dict()
|
loaded_model_state_dict = loaded_model.state_dict()
|
||||||
|
|
||||||
|
non_persistent_buffers = {}
|
||||||
|
for key in loaded_model_state_dict.keys():
|
||||||
|
if key not in model_state_dict.keys():
|
||||||
|
non_persistent_buffers[key] = loaded_model_state_dict[key]
|
||||||
|
|
||||||
|
loaded_model_state_dict = {
|
||||||
|
key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers
|
||||||
|
}
|
||||||
|
|
||||||
self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))
|
self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))
|
||||||
|
|
||||||
|
model_buffers = list(model.buffers())
|
||||||
|
for non_persistent_buffer in non_persistent_buffers.values():
|
||||||
|
found_buffer = False
|
||||||
|
for i, model_buffer in enumerate(model_buffers):
|
||||||
|
if torch.equal(non_persistent_buffer, model_buffer):
|
||||||
|
found_buffer = True
|
||||||
|
break
|
||||||
|
|
||||||
|
self.assertTrue(found_buffer)
|
||||||
|
model_buffers.pop(i)
|
||||||
|
|
||||||
models_equal = True
|
models_equal = True
|
||||||
for layer_name, p1 in model_state_dict.items():
|
for layer_name, p1 in model_state_dict.items():
|
||||||
p2 = loaded_model_state_dict[layer_name]
|
p2 = loaded_model_state_dict[layer_name]
|
||||||
|
|||||||
@@ -229,6 +229,17 @@ class EncodecModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
|
|||||||
|
|
||||||
self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))
|
self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))
|
||||||
|
|
||||||
|
model_buffers = list(model.buffers())
|
||||||
|
for non_persistent_buffer in non_persistent_buffers.values():
|
||||||
|
found_buffer = False
|
||||||
|
for i, model_buffer in enumerate(model_buffers):
|
||||||
|
if torch.equal(non_persistent_buffer, model_buffer):
|
||||||
|
found_buffer = True
|
||||||
|
break
|
||||||
|
|
||||||
|
self.assertTrue(found_buffer)
|
||||||
|
model_buffers.pop(i)
|
||||||
|
|
||||||
model_buffers = list(model.buffers())
|
model_buffers = list(model.buffers())
|
||||||
for non_persistent_buffer in non_persistent_buffers.values():
|
for non_persistent_buffer in non_persistent_buffers.values():
|
||||||
found_buffer = False
|
found_buffer = False
|
||||||
|
|||||||
@@ -964,8 +964,28 @@ class FlavaModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
|||||||
# Non persistent buffers won't be in original state dict
|
# Non persistent buffers won't be in original state dict
|
||||||
loaded_model_state_dict.pop("text_model.embeddings.token_type_ids", None)
|
loaded_model_state_dict.pop("text_model.embeddings.token_type_ids", None)
|
||||||
|
|
||||||
|
non_persistent_buffers = {}
|
||||||
|
for key in loaded_model_state_dict.keys():
|
||||||
|
if key not in model_state_dict.keys():
|
||||||
|
non_persistent_buffers[key] = loaded_model_state_dict[key]
|
||||||
|
|
||||||
|
loaded_model_state_dict = {
|
||||||
|
key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers
|
||||||
|
}
|
||||||
|
|
||||||
self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))
|
self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))
|
||||||
|
|
||||||
|
model_buffers = list(model.buffers())
|
||||||
|
for non_persistent_buffer in non_persistent_buffers.values():
|
||||||
|
found_buffer = False
|
||||||
|
for i, model_buffer in enumerate(model_buffers):
|
||||||
|
if torch.equal(non_persistent_buffer, model_buffer):
|
||||||
|
found_buffer = True
|
||||||
|
break
|
||||||
|
|
||||||
|
self.assertTrue(found_buffer)
|
||||||
|
model_buffers.pop(i)
|
||||||
|
|
||||||
models_equal = True
|
models_equal = True
|
||||||
for layer_name, p1 in model_state_dict.items():
|
for layer_name, p1 in model_state_dict.items():
|
||||||
p2 = loaded_model_state_dict[layer_name]
|
p2 = loaded_model_state_dict[layer_name]
|
||||||
|
|||||||
@@ -321,6 +321,17 @@ class GraphormerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
|
|||||||
|
|
||||||
self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))
|
self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))
|
||||||
|
|
||||||
|
model_buffers = list(model.buffers())
|
||||||
|
for non_persistent_buffer in non_persistent_buffers.values():
|
||||||
|
found_buffer = False
|
||||||
|
for i, model_buffer in enumerate(model_buffers):
|
||||||
|
if torch.equal(non_persistent_buffer, model_buffer):
|
||||||
|
found_buffer = True
|
||||||
|
break
|
||||||
|
|
||||||
|
self.assertTrue(found_buffer)
|
||||||
|
model_buffers.pop(i)
|
||||||
|
|
||||||
model_buffers = list(model.buffers())
|
model_buffers = list(model.buffers())
|
||||||
for non_persistent_buffer in non_persistent_buffers.values():
|
for non_persistent_buffer in non_persistent_buffers.values():
|
||||||
found_buffer = False
|
found_buffer = False
|
||||||
|
|||||||
@@ -630,8 +630,28 @@ class GroupViTModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase
|
|||||||
model_state_dict = model.state_dict()
|
model_state_dict = model.state_dict()
|
||||||
loaded_model_state_dict = loaded_model.state_dict()
|
loaded_model_state_dict = loaded_model.state_dict()
|
||||||
|
|
||||||
|
non_persistent_buffers = {}
|
||||||
|
for key in loaded_model_state_dict.keys():
|
||||||
|
if key not in model_state_dict.keys():
|
||||||
|
non_persistent_buffers[key] = loaded_model_state_dict[key]
|
||||||
|
|
||||||
|
loaded_model_state_dict = {
|
||||||
|
key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers
|
||||||
|
}
|
||||||
|
|
||||||
self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))
|
self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))
|
||||||
|
|
||||||
|
model_buffers = list(model.buffers())
|
||||||
|
for non_persistent_buffer in non_persistent_buffers.values():
|
||||||
|
found_buffer = False
|
||||||
|
for i, model_buffer in enumerate(model_buffers):
|
||||||
|
if torch.equal(non_persistent_buffer, model_buffer):
|
||||||
|
found_buffer = True
|
||||||
|
break
|
||||||
|
|
||||||
|
self.assertTrue(found_buffer)
|
||||||
|
model_buffers.pop(i)
|
||||||
|
|
||||||
models_equal = True
|
models_equal = True
|
||||||
for layer_name, p1 in model_state_dict.items():
|
for layer_name, p1 in model_state_dict.items():
|
||||||
p2 = loaded_model_state_dict[layer_name]
|
p2 = loaded_model_state_dict[layer_name]
|
||||||
|
|||||||
@@ -500,6 +500,17 @@ class ImageGPTModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
|
|||||||
|
|
||||||
self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))
|
self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))
|
||||||
|
|
||||||
|
model_buffers = list(model.buffers())
|
||||||
|
for non_persistent_buffer in non_persistent_buffers.values():
|
||||||
|
found_buffer = False
|
||||||
|
for i, model_buffer in enumerate(model_buffers):
|
||||||
|
if torch.equal(non_persistent_buffer, model_buffer):
|
||||||
|
found_buffer = True
|
||||||
|
break
|
||||||
|
|
||||||
|
self.assertTrue(found_buffer)
|
||||||
|
model_buffers.pop(i)
|
||||||
|
|
||||||
model_buffers = list(model.buffers())
|
model_buffers = list(model.buffers())
|
||||||
for non_persistent_buffer in non_persistent_buffers.values():
|
for non_persistent_buffer in non_persistent_buffers.values():
|
||||||
found_buffer = False
|
found_buffer = False
|
||||||
|
|||||||
@@ -491,8 +491,28 @@ class OwlViTModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
|||||||
model_state_dict = model.state_dict()
|
model_state_dict = model.state_dict()
|
||||||
loaded_model_state_dict = loaded_model.state_dict()
|
loaded_model_state_dict = loaded_model.state_dict()
|
||||||
|
|
||||||
|
non_persistent_buffers = {}
|
||||||
|
for key in loaded_model_state_dict.keys():
|
||||||
|
if key not in model_state_dict.keys():
|
||||||
|
non_persistent_buffers[key] = loaded_model_state_dict[key]
|
||||||
|
|
||||||
|
loaded_model_state_dict = {
|
||||||
|
key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers
|
||||||
|
}
|
||||||
|
|
||||||
self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))
|
self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))
|
||||||
|
|
||||||
|
model_buffers = list(model.buffers())
|
||||||
|
for non_persistent_buffer in non_persistent_buffers.values():
|
||||||
|
found_buffer = False
|
||||||
|
for i, model_buffer in enumerate(model_buffers):
|
||||||
|
if torch.equal(non_persistent_buffer, model_buffer):
|
||||||
|
found_buffer = True
|
||||||
|
break
|
||||||
|
|
||||||
|
self.assertTrue(found_buffer)
|
||||||
|
model_buffers.pop(i)
|
||||||
|
|
||||||
models_equal = True
|
models_equal = True
|
||||||
for layer_name, p1 in model_state_dict.items():
|
for layer_name, p1 in model_state_dict.items():
|
||||||
p2 = loaded_model_state_dict[layer_name]
|
p2 = loaded_model_state_dict[layer_name]
|
||||||
@@ -670,8 +690,28 @@ class OwlViTForObjectDetectionTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
model_state_dict = model.state_dict()
|
model_state_dict = model.state_dict()
|
||||||
loaded_model_state_dict = loaded_model.state_dict()
|
loaded_model_state_dict = loaded_model.state_dict()
|
||||||
|
|
||||||
|
non_persistent_buffers = {}
|
||||||
|
for key in loaded_model_state_dict.keys():
|
||||||
|
if key not in model_state_dict.keys():
|
||||||
|
non_persistent_buffers[key] = loaded_model_state_dict[key]
|
||||||
|
|
||||||
|
loaded_model_state_dict = {
|
||||||
|
key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers
|
||||||
|
}
|
||||||
|
|
||||||
self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))
|
self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))
|
||||||
|
|
||||||
|
model_buffers = list(model.buffers())
|
||||||
|
for non_persistent_buffer in non_persistent_buffers.values():
|
||||||
|
found_buffer = False
|
||||||
|
for i, model_buffer in enumerate(model_buffers):
|
||||||
|
if torch.equal(non_persistent_buffer, model_buffer):
|
||||||
|
found_buffer = True
|
||||||
|
break
|
||||||
|
|
||||||
|
self.assertTrue(found_buffer)
|
||||||
|
model_buffers.pop(i)
|
||||||
|
|
||||||
models_equal = True
|
models_equal = True
|
||||||
for layer_name, p1 in model_state_dict.items():
|
for layer_name, p1 in model_state_dict.items():
|
||||||
p2 = loaded_model_state_dict[layer_name]
|
p2 = loaded_model_state_dict[layer_name]
|
||||||
|
|||||||
@@ -669,8 +669,28 @@ class Pix2StructModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
|
|||||||
model_state_dict = model.state_dict()
|
model_state_dict = model.state_dict()
|
||||||
loaded_model_state_dict = loaded_model.state_dict()
|
loaded_model_state_dict = loaded_model.state_dict()
|
||||||
|
|
||||||
|
non_persistent_buffers = {}
|
||||||
|
for key in loaded_model_state_dict.keys():
|
||||||
|
if key not in model_state_dict.keys():
|
||||||
|
non_persistent_buffers[key] = loaded_model_state_dict[key]
|
||||||
|
|
||||||
|
loaded_model_state_dict = {
|
||||||
|
key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers
|
||||||
|
}
|
||||||
|
|
||||||
self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))
|
self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))
|
||||||
|
|
||||||
|
model_buffers = list(model.buffers())
|
||||||
|
for non_persistent_buffer in non_persistent_buffers.values():
|
||||||
|
found_buffer = False
|
||||||
|
for i, model_buffer in enumerate(model_buffers):
|
||||||
|
if torch.equal(non_persistent_buffer, model_buffer):
|
||||||
|
found_buffer = True
|
||||||
|
break
|
||||||
|
|
||||||
|
self.assertTrue(found_buffer)
|
||||||
|
model_buffers.pop(i)
|
||||||
|
|
||||||
models_equal = True
|
models_equal = True
|
||||||
for layer_name, p1 in model_state_dict.items():
|
for layer_name, p1 in model_state_dict.items():
|
||||||
p2 = loaded_model_state_dict[layer_name]
|
p2 = loaded_model_state_dict[layer_name]
|
||||||
|
|||||||
@@ -712,8 +712,28 @@ class Speech2TextModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTest
|
|||||||
model_state_dict = model.state_dict()
|
model_state_dict = model.state_dict()
|
||||||
loaded_model_state_dict = loaded_model.state_dict()
|
loaded_model_state_dict = loaded_model.state_dict()
|
||||||
|
|
||||||
|
non_persistent_buffers = {}
|
||||||
|
for key in loaded_model_state_dict.keys():
|
||||||
|
if key not in model_state_dict.keys():
|
||||||
|
non_persistent_buffers[key] = loaded_model_state_dict[key]
|
||||||
|
|
||||||
|
loaded_model_state_dict = {
|
||||||
|
key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers
|
||||||
|
}
|
||||||
|
|
||||||
self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))
|
self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))
|
||||||
|
|
||||||
|
model_buffers = list(model.buffers())
|
||||||
|
for non_persistent_buffer in non_persistent_buffers.values():
|
||||||
|
found_buffer = False
|
||||||
|
for i, model_buffer in enumerate(model_buffers):
|
||||||
|
if torch.equal(non_persistent_buffer, model_buffer):
|
||||||
|
found_buffer = True
|
||||||
|
break
|
||||||
|
|
||||||
|
self.assertTrue(found_buffer)
|
||||||
|
model_buffers.pop(i)
|
||||||
|
|
||||||
models_equal = True
|
models_equal = True
|
||||||
for layer_name, p1 in model_state_dict.items():
|
for layer_name, p1 in model_state_dict.items():
|
||||||
p2 = loaded_model_state_dict[layer_name]
|
p2 = loaded_model_state_dict[layer_name]
|
||||||
|
|||||||
@@ -814,8 +814,28 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
|||||||
model_state_dict = model.state_dict()
|
model_state_dict = model.state_dict()
|
||||||
loaded_model_state_dict = loaded_model.state_dict()
|
loaded_model_state_dict = loaded_model.state_dict()
|
||||||
|
|
||||||
|
non_persistent_buffers = {}
|
||||||
|
for key in loaded_model_state_dict.keys():
|
||||||
|
if key not in model_state_dict.keys():
|
||||||
|
non_persistent_buffers[key] = loaded_model_state_dict[key]
|
||||||
|
|
||||||
|
loaded_model_state_dict = {
|
||||||
|
key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers
|
||||||
|
}
|
||||||
|
|
||||||
self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))
|
self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))
|
||||||
|
|
||||||
|
model_buffers = list(model.buffers())
|
||||||
|
for non_persistent_buffer in non_persistent_buffers.values():
|
||||||
|
found_buffer = False
|
||||||
|
for i, model_buffer in enumerate(model_buffers):
|
||||||
|
if torch.equal(non_persistent_buffer, model_buffer):
|
||||||
|
found_buffer = True
|
||||||
|
break
|
||||||
|
|
||||||
|
self.assertTrue(found_buffer)
|
||||||
|
model_buffers.pop(i)
|
||||||
|
|
||||||
models_equal = True
|
models_equal = True
|
||||||
for layer_name, p1 in model_state_dict.items():
|
for layer_name, p1 in model_state_dict.items():
|
||||||
p2 = loaded_model_state_dict[layer_name]
|
p2 = loaded_model_state_dict[layer_name]
|
||||||
|
|||||||
@@ -612,8 +612,28 @@ class XCLIPModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
|||||||
model_state_dict = model.state_dict()
|
model_state_dict = model.state_dict()
|
||||||
loaded_model_state_dict = loaded_model.state_dict()
|
loaded_model_state_dict = loaded_model.state_dict()
|
||||||
|
|
||||||
|
non_persistent_buffers = {}
|
||||||
|
for key in loaded_model_state_dict.keys():
|
||||||
|
if key not in model_state_dict.keys():
|
||||||
|
non_persistent_buffers[key] = loaded_model_state_dict[key]
|
||||||
|
|
||||||
|
loaded_model_state_dict = {
|
||||||
|
key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers
|
||||||
|
}
|
||||||
|
|
||||||
self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))
|
self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))
|
||||||
|
|
||||||
|
model_buffers = list(model.buffers())
|
||||||
|
for non_persistent_buffer in non_persistent_buffers.values():
|
||||||
|
found_buffer = False
|
||||||
|
for i, model_buffer in enumerate(model_buffers):
|
||||||
|
if torch.equal(non_persistent_buffer, model_buffer):
|
||||||
|
found_buffer = True
|
||||||
|
break
|
||||||
|
|
||||||
|
self.assertTrue(found_buffer)
|
||||||
|
model_buffers.pop(i)
|
||||||
|
|
||||||
models_equal = True
|
models_equal = True
|
||||||
for layer_name, p1 in model_state_dict.items():
|
for layer_name, p1 in model_state_dict.items():
|
||||||
p2 = loaded_model_state_dict[layer_name]
|
p2 = loaded_model_state_dict[layer_name]
|
||||||
|
|||||||
Reference in New Issue
Block a user