Protect get_default_device for torch<2.3 (#38376)

* Update modeling_utils.py

* CIs
This commit is contained in:
Cyril Vallez
2025-05-26 15:00:09 +02:00
committed by GitHub
parent bff32678cc
commit b5b76b5561

View File

@@ -324,7 +324,8 @@ def get_torch_context_manager_or_global_device():
is not "cpu". This is used to infer the correct device to load the model on, in case `device_map` is not provided.
"""
device_in_context = torch.tensor([]).device
default_device = torch.get_default_device()
# `get_default_device` was only introduced in torch>=2.3 - use cpu otherwise to align the behavior
default_device = torch.get_default_device() if is_torch_greater_or_equal("2.3") else torch.device("cpu")
# This case means no context manager was used -> we still check if the default that was potentially set is not cpu
if device_in_context == default_device:
if default_device != torch.device("cpu"):