Detect and use device context manager or global device in from_pretrained (#37216)

* Update modeling_utils.py

* improve

* Update modeling_utils.py

* Update test_modeling_common.py

* Update test_modeling_timm_backbone.py

* Update test_modeling_common.py

* Update test_modeling_common.py

* Update test_modeling_common.py

* Update test_modeling_common.py

* CIs
This commit is contained in:
Cyril Vallez
2025-04-15 09:59:20 +02:00
committed by GitHub
parent 4e63a1747c
commit c8e0e603de
3 changed files with 111 additions and 1 deletions

View File

@@ -168,6 +168,18 @@ class TimmBackboneModelTest(ModelTesterMixin, BackboneTesterMixin, PipelineTeste
def test_save_load_low_cpu_mem_usage_no_safetensors(self):
pass
@unittest.skip(reason="TimmBackbone uses its own `from_pretrained` without device_map support")
def test_can_load_with_device_context_manager(self):
pass
@unittest.skip(reason="TimmBackbone uses its own `from_pretrained` without device_map support")
def test_can_load_with_global_device_set(self):
pass
@unittest.skip(reason="TimmBackbone uses its own `from_pretrained` without device_map support")
def test_cannot_load_with_meta_device_context_manager(self):
pass
@unittest.skip(reason="model weights aren't tied in TimmBackbone.")
def test_tie_model_weights(self):
pass

View File

@@ -4454,6 +4454,73 @@ class ModelTesterMixin:
),
)
@require_torch_accelerator
def test_can_load_with_device_context_manager(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
# Need to specify index 0 here, as `torch_device` is simply the str of the type, e.g. "cuda"
device = torch.device(torch_device, index=0)
for model_class in self.all_model_classes:
# Need to deepcopy here as it is modified in-place in save_pretrained (it sets sdpa for default attn, which
# is not supported for e.g. dpt_hybrid)
model = model_class(copy.deepcopy(config))
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
with device:
new_model = model_class.from_pretrained(tmpdirname)
unique_devices = {param.device for param in new_model.parameters()} | {
buffer.device for buffer in new_model.buffers()
}
self.assertEqual(
unique_devices, {device}, f"All parameters should be on {device}, but found {unique_devices}."
)
@require_torch_accelerator
def test_can_load_with_global_device_set(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
# Need to specify index 0 here, as `torch_device` is simply the str of the type, e.g. "cuda"
device = torch.device(torch_device, index=0)
default_device = torch.get_default_device()
for model_class in self.all_model_classes:
# Need to deepcopy here as it is modified in-place in save_pretrained (it sets sdpa for default attn, which
# is not supported for e.g. dpt_hybrid)
model = model_class(copy.deepcopy(config))
# set a global gpu device
torch.set_default_device(device)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
new_model = model_class.from_pretrained(tmpdirname)
unique_devices = {param.device for param in new_model.parameters()} | {
buffer.device for buffer in new_model.buffers()
}
# set back the correct device
torch.set_default_device(default_device)
self.assertEqual(
unique_devices, {device}, f"All parameters should be on {device}, but found {unique_devices}."
)
def test_cannot_load_with_meta_device_context_manager(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
# Need to deepcopy here as it is modified in-place in save_pretrained (it sets sdpa for default attn, which
# is not supported for e.g. dpt_hybrid)
model = model_class(copy.deepcopy(config))
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
# This should raise an error with meta device
with self.assertRaises(ValueError, msg="`from_pretrained` is not compatible with a meta device"):
with torch.device("meta"):
_ = model_class.from_pretrained(tmpdirname)
global_rng = random.Random()