Detect and use device context manager or global device in from_pretrained (#37216)

* Update modeling_utils.py

* improve

* Update modeling_utils.py

* Update test_modeling_common.py

* Update test_modeling_timm_backbone.py

* Update test_modeling_common.py

* Update test_modeling_common.py

* Update test_modeling_common.py

* Update test_modeling_common.py

* CIs
This commit is contained in:
Cyril Vallez
2025-04-15 09:59:20 +02:00
committed by GitHub
parent 4e63a1747c
commit c8e0e603de
3 changed files with 111 additions and 1 deletions

View File

@@ -287,6 +287,21 @@ def restore_default_torch_dtype(func):
return _wrapper
def get_torch_context_manager_or_global_device():
    """
    Test if a device context manager is currently in use, or if it is not the case, check if the default device
    is not "cpu". This is used to infer the correct device to load the model on, in case `device_map` is not provided.

    Returns:
        `torch.device` or `None`: the device in effect for new tensors when it differs from the default device,
        otherwise the default device if it is not "cpu", otherwise `None` (nothing special was requested).
    """
    cpu_device = torch.device("cpu")
    # A freshly created tensor lands on whatever device is currently in effect for tensor creation
    device_in_context = torch.tensor([]).device
    # `torch.get_default_device` is not available on older torch releases: resolve it defensively instead of
    # raising an `AttributeError`, and use "cpu" as the baseline default in that case. The empty tensor above
    # still reveals any non-cpu device that was requested, so detection keeps working.
    default_device_getter = getattr(torch, "get_default_device", None)
    default_device = default_device_getter() if default_device_getter is not None else cpu_device
    # Both differ -> a device was requested that the reported default does not account for: use it
    if device_in_context != default_device:
        return device_in_context
    # Both agree -> only report the default that was potentially set if it is not cpu
    if default_device != cpu_device:
        return default_device
    return None
def get_parameter_device(parameter: Union[nn.Module, "ModuleUtilsMixin"]):
try:
return next(parameter.parameters()).device
@@ -4153,6 +4168,19 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
else:
_adapter_model_path = None
# Potentially detect context manager or global device, and use it (only if no device_map was provided)
if device_map is None:
device_in_context = get_torch_context_manager_or_global_device()
if device_in_context == torch.device("meta"):
raise ValueError(
(
"`from_pretrained` is not compatible with a meta device context manager or `torch.set_default_device('meta')` "
"as its purpose is to load weights. If you want to initialize a model on the meta device, use the context manager "
"or global device with `from_config`, or `ModelClass(config)`"
)
)
device_map = device_in_context
# change device_map into a map if we passed an int, a str or a torch.device
if isinstance(device_map, torch.device):
device_map = {"": device_map}
@@ -4177,7 +4205,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
raise ValueError("DeepSpeed Zero-3 is not compatible with passing a `device_map`.")
if not is_accelerate_available():
raise ValueError(
"Using a `device_map` or `tp_plan` requires `accelerate`. You can install it with `pip install accelerate`"
(
"Using a `device_map`, `tp_plan`, `torch.device` context manager or setting `torch.set_default_device(device)` "
"requires `accelerate`. You can install it with `pip install accelerate`"
)
)
# handling bnb config from kwargs, remove after `load_in_{4/8}bit` deprecation.