Small fix on context manager detection (#37562)

* small fixes

* Update modeling_utils.py

* test

* Update test_modeling_common.py

* Update test_modeling_timm_backbone.py

* more general

* simpler
This commit is contained in:
Cyril Vallez
2025-04-17 15:39:44 +02:00
committed by GitHub
parent c7d3cc67a1
commit 58e5e976e0
6 changed files with 36 additions and 18 deletions

View File

@@ -4167,15 +4167,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
_adapter_model_path = None
# Potentially detect context manager or global device, and use it (only if no device_map was provided)
if device_map is None:
if device_map is None and not is_deepspeed_zero3_enabled():
device_in_context = get_torch_context_manager_or_global_device()
if device_in_context == torch.device("meta"):
raise ValueError(
(
"`from_pretrained` is not compatible with a meta device context manager or `torch.set_default_device('meta')` "
"as its purpose is to load weights. If you want to initialize a model on the meta device, use the context manager "
"or global device with `from_config`, or `ModelClass(config)`"
)
# TODO Cyril: raise an error instead of the warning in v4.53 (and change the test to check for raise instead of success)
logger.warning(
"We detected that you are using `from_pretrained` with a meta device context manager or `torch.set_default_device('meta')`\n"
"This is an anti-pattern and will raise an Error in version v4.53\nIf you want to initialize a model on the meta device, use "
"the context manager or global device with `from_config`, or `ModelClass(config)`"
)
device_map = device_in_context
@@ -5834,6 +5833,16 @@ def expand_device_map(device_map, param_names):
return new_device_map
def is_accelerator_device(device: Union[str, int, torch.device]) -> bool:
"""Check if the device is an accelerator. We need to function, as device_map can be "disk" as well, which is not
a proper `torch.device`.
"""
if device == "disk":
return False
else:
return torch.device(device).type not in ["meta", "cpu"]
def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: Dict, factor=2):
"""This function warm-ups the caching allocator based on the size of the model tensors that will reside on each
device. It allows to have one large call to Malloc, instead of recursively calling it later when loading
@@ -5853,9 +5862,9 @@ def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: Dict,
- Loading speed bottleneck is now almost only tensor copy (i.e. changing the dtype) and moving the tensors to the devices.
However, we cannot really improve on those aspects obviously, as the data needs to be moved/copied in the end.
"""
# Remove disk and cpu devices, and cast to proper torch.device
# Remove disk, cpu and meta devices, and cast to proper torch.device
accelerator_device_map = {
param: torch.device(device) for param, device in expanded_device_map.items() if device not in ["cpu", "disk"]
param: torch.device(device) for param, device in expanded_device_map.items() if is_accelerator_device(device)
}
if not len(accelerator_device_map):
return