Small fix on context manager detection (#37562)

* small fixes

* Update modeling_utils.py

* test

* Update test_modeling_common.py

* Update test_modeling_timm_backbone.py

* more general

* simpler
This commit is contained in:
Cyril Vallez
2025-04-17 15:39:44 +02:00
committed by GitHub
parent c7d3cc67a1
commit 58e5e976e0
6 changed files with 36 additions and 18 deletions

View File

@@ -4590,7 +4590,7 @@ class ModelTesterMixin:
unique_devices, {device}, f"All parameters should be on {device}, but found {unique_devices}."
)
def test_cannot_load_with_meta_device_context_manager(self):
def test_can_load_with_meta_device_context_manager(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
# Need to deepcopy here as it is modified in-place in save_pretrained (it sets sdpa for default attn, which
@@ -4600,10 +4600,17 @@ class ModelTesterMixin:
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
# This should raise an error with meta device
with self.assertRaises(ValueError, msg="`from_pretrained` is not compatible with a meta device"):
with torch.device("meta"):
_ = model_class.from_pretrained(tmpdirname)
with torch.device("meta"):
new_model = model_class.from_pretrained(tmpdirname)
unique_devices = {param.device for param in new_model.parameters()} | {
buffer.device for buffer in new_model.buffers()
}
self.assertEqual(
unique_devices,
{torch.device("meta")},
f"All parameters should be on meta device, but found {unique_devices}.",
)
global_rng = random.Random()