Detect and use device context manager or global device in from_pretrained (#37216)

* Update modeling_utils.py

* improve

* Update modeling_utils.py

* Update test_modeling_common.py

* Update test_modeling_timm_backbone.py

* Update test_modeling_common.py

* Update test_modeling_common.py

* Update test_modeling_common.py

* Update test_modeling_common.py

* CIs
This commit is contained in:
Cyril Vallez
2025-04-15 09:59:20 +02:00
committed by GitHub
parent 4e63a1747c
commit c8e0e603de
3 changed files with 111 additions and 1 deletions

View File

@@ -287,6 +287,21 @@ def restore_default_torch_dtype(func):
return _wrapper
def get_torch_context_manager_or_global_device():
"""
Test if a device context manager is currently in use, or if it is not the case, check if the default device
is not "cpu". This is used to infer the correct device to load the model on, in case `device_map` is not provided.
"""
device_in_context = torch.tensor([]).device
default_device = torch.get_default_device()
# This case means no context manager was used -> we still check if the default that was potentially set is not cpu
if device_in_context == default_device:
if default_device != torch.device("cpu"):
return default_device
return None
return device_in_context
def get_parameter_device(parameter: Union[nn.Module, "ModuleUtilsMixin"]):
try:
return next(parameter.parameters()).device
@@ -4153,6 +4168,19 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
else:
_adapter_model_path = None
# Potentially detect context manager or global device, and use it (only if no device_map was provided)
if device_map is None:
device_in_context = get_torch_context_manager_or_global_device()
if device_in_context == torch.device("meta"):
raise ValueError(
(
"`from_pretrained` is not compatible with a meta device context manager or `torch.set_default_device('meta')` "
"as its purpose is to load weights. If you want to initialize a model on the meta device, use the context manager "
"or global device with `from_config`, or `ModelClass(config)`"
)
)
device_map = device_in_context
# change device_map into a map if we passed an int, a str or a torch.device
if isinstance(device_map, torch.device):
device_map = {"": device_map}
@@ -4177,7 +4205,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
raise ValueError("DeepSpeed Zero-3 is not compatible with passing a `device_map`.")
if not is_accelerate_available():
raise ValueError(
"Using a `device_map` or `tp_plan` requires `accelerate`. You can install it with `pip install accelerate`"
(
"Using a `device_map`, `tp_plan`, `torch.device` context manager or setting `torch.set_default_device(device)` "
"requires `accelerate`. You can install it with `pip install accelerate`"
)
)
# handling bnb config from kwargs, remove after `load_in_{4/8}bit` deprecation.

View File

@@ -168,6 +168,18 @@ class TimmBackboneModelTest(ModelTesterMixin, BackboneTesterMixin, PipelineTeste
def test_save_load_low_cpu_mem_usage_no_safetensors(self):
pass
# Overrides the test of the same name inherited from `ModelTesterMixin`; the body is intentionally empty because the
# test is skipped for this model (see the skip reason).
@unittest.skip(reason="TimmBackbone uses its own `from_pretrained` without device_map support")
def test_can_load_with_device_context_manager(self):
pass
# Overrides the test of the same name inherited from `ModelTesterMixin`; the body is intentionally empty because the
# test is skipped for this model (see the skip reason).
@unittest.skip(reason="TimmBackbone uses its own `from_pretrained` without device_map support")
def test_can_load_with_global_device_set(self):
pass
# Overrides the test of the same name inherited from `ModelTesterMixin`; the body is intentionally empty because the
# test is skipped for this model (see the skip reason).
@unittest.skip(reason="TimmBackbone uses its own `from_pretrained` without device_map support")
def test_cannot_load_with_meta_device_context_manager(self):
pass
# Skip stub: the body is intentionally empty because the decorator skips the test.
# NOTE(review): presumably overrides a common test inherited from one of the tester mixins - confirm.
@unittest.skip(reason="model weights aren't tied in TimmBackbone.")
def test_tie_model_weights(self):
pass

View File

@@ -4454,6 +4454,73 @@ class ModelTesterMixin:
),
)
@require_torch_accelerator
def test_can_load_with_device_context_manager(self):
    """
    Check that calling `from_pretrained` inside a `torch.device` context manager places every parameter and buffer
    of the loaded model on that device.
    """
    config, _ = self.model_tester.prepare_config_and_inputs_for_common()
    # `torch_device` is only the str of the device type (e.g. "cuda"), so index 0 has to be made explicit
    device = torch.device(torch_device, index=0)
    for model_class in self.all_model_classes:
        # `save_pretrained` modifies the config in-place (it sets sdpa for default attn, which is not supported
        # for e.g. dpt_hybrid), hence the deepcopy
        reference_model = model_class(copy.deepcopy(config))

        with tempfile.TemporaryDirectory() as tmpdirname:
            reference_model.save_pretrained(tmpdirname)

            with device:
                reloaded = model_class.from_pretrained(tmpdirname)
                param_devices = {p.device for p in reloaded.parameters()}
                buffer_devices = {b.device for b in reloaded.buffers()}

        unique_devices = param_devices | buffer_devices
        self.assertEqual(
            unique_devices, {device}, f"All parameters should be on {device}, but found {unique_devices}."
        )
@require_torch_accelerator
def test_can_load_with_global_device_set(self):
    """
    Check that `from_pretrained` honors a device set globally with `torch.set_default_device`: every parameter and
    buffer of the loaded model must end up on that device.

    The previous default device is always restored, even when saving or loading raises, so that a failure here
    cannot leak the accelerator default into the tests that run afterwards.
    """
    config, _ = self.model_tester.prepare_config_and_inputs_for_common()
    # Need to specify index 0 here, as `torch_device` is simply the str of the type, e.g. "cuda"
    device = torch.device(torch_device, index=0)
    default_device = torch.get_default_device()
    for model_class in self.all_model_classes:
        # Need to deepcopy here as it is modified in-place in save_pretrained (it sets sdpa for default attn, which
        # is not supported for e.g. dpt_hybrid)
        model = model_class(copy.deepcopy(config))

        # set a global gpu device
        torch.set_default_device(device)
        try:
            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)
                new_model = model_class.from_pretrained(tmpdirname)
                unique_devices = {param.device for param in new_model.parameters()} | {
                    buffer.device for buffer in new_model.buffers()
                }
        finally:
            # set back the correct device, whatever happened above (global state shared by the whole test session)
            torch.set_default_device(default_device)

        self.assertEqual(
            unique_devices, {device}, f"All parameters should be on {device}, but found {unique_devices}."
        )
def test_cannot_load_with_meta_device_context_manager(self):
    """
    Check that `from_pretrained` refuses to run inside a `torch.device("meta")` context manager.

    `assertRaisesRegex` is used so that the text of the error is actually verified: the `msg` argument of
    `assertRaises` is only the message shown when the assertion fails and is never matched against the exception.
    """
    config, _ = self.model_tester.prepare_config_and_inputs_for_common()
    for model_class in self.all_model_classes:
        # Need to deepcopy here as it is modified in-place in save_pretrained (it sets sdpa for default attn, which
        # is not supported for e.g. dpt_hybrid)
        model = model_class(copy.deepcopy(config))

        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save_pretrained(tmpdirname)

            # This should raise an error with meta device (the pattern contains no regex special characters)
            with self.assertRaisesRegex(ValueError, "`from_pretrained` is not compatible with a meta device"):
                with torch.device("meta"):
                    _ = model_class.from_pretrained(tmpdirname)
# Module-level `random.Random` instance.
# NOTE(review): presumably the shared RNG used by helpers further down in this file - confirm.
global_rng = random.Random()