Small fix on context manager detection (#37562)
* small fixes * Update modeling_utils.py * test * Update test_modeling_common.py * Update test_modeling_timm_backbone.py * more general * simpler
This commit is contained in:
@@ -4167,15 +4167,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
_adapter_model_path = None
|
||||
|
||||
# Potentially detect context manager or global device, and use it (only if no device_map was provided)
|
||||
if device_map is None:
|
||||
if device_map is None and not is_deepspeed_zero3_enabled():
|
||||
device_in_context = get_torch_context_manager_or_global_device()
|
||||
if device_in_context == torch.device("meta"):
|
||||
raise ValueError(
|
||||
(
|
||||
"`from_pretrained` is not compatible with a meta device context manager or `torch.set_default_device('meta')` "
|
||||
"as its purpose is to load weights. If you want to initialize a model on the meta device, use the context manager "
|
||||
"or global device with `from_config`, or `ModelClass(config)`"
|
||||
)
|
||||
# TODO Cyril: raise an error instead of the warning in v4.53 (and change the test to check for raise instead of success)
|
||||
logger.warning(
|
||||
"We detected that you are using `from_pretrained` with a meta device context manager or `torch.set_default_device('meta')`\n"
|
||||
"This is an anti-pattern and will raise an Error in version v4.53\nIf you want to initialize a model on the meta device, use "
|
||||
"the context manager or global device with `from_config`, or `ModelClass(config)`"
|
||||
)
|
||||
device_map = device_in_context
|
||||
|
||||
@@ -5834,6 +5833,16 @@ def expand_device_map(device_map, param_names):
|
||||
return new_device_map
|
||||
|
||||
|
||||
def is_accelerator_device(device: Union[str, int, torch.device]) -> bool:
|
||||
"""Check if the device is an accelerator. We need this function, as device_map can be "disk" as well, which is not
|
||||
a proper `torch.device`.
|
||||
"""
|
||||
if device == "disk":
|
||||
return False
|
||||
else:
|
||||
return torch.device(device).type not in ["meta", "cpu"]
|
||||
|
||||
|
||||
def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: Dict, factor=2):
|
||||
"""This function warms up the caching allocator based on the size of the model tensors that will reside on each
|
||||
device. It allows us to make one large call to Malloc, instead of recursively calling it later when loading
|
||||
@@ -5853,9 +5862,9 @@ def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: Dict,
|
||||
- Loading speed bottleneck is now almost only tensor copy (i.e. changing the dtype) and moving the tensors to the devices.
|
||||
However, we cannot really improve on those aspects obviously, as the data needs to be moved/copied in the end.
|
||||
"""
|
||||
# Remove disk and cpu devices, and cast to proper torch.device
|
||||
# Remove disk, cpu and meta devices, and cast to proper torch.device
|
||||
accelerator_device_map = {
|
||||
param: torch.device(device) for param, device in expanded_device_map.items() if device not in ["cpu", "disk"]
|
||||
param: torch.device(device) for param, device in expanded_device_map.items() if is_accelerator_device(device)
|
||||
}
|
||||
if not len(accelerator_device_map):
|
||||
return
|
||||
|
||||
Reference in New Issue
Block a user