Protect get_default_device for torch<2.3 (#38376)
* Update modeling_utils.py * CIs
This commit is contained in:
@@ -324,7 +324,8 @@ def get_torch_context_manager_or_global_device():
|
||||
is not "cpu". This is used to infer the correct device to load the model on, in case `device_map` is not provided.
|
||||
"""
|
||||
device_in_context = torch.tensor([]).device
|
||||
default_device = torch.get_default_device()
|
||||
# `get_default_device` was only introduced in torch>=2.3 - use cpu otherwise to align the behavior
|
||||
default_device = torch.get_default_device() if is_torch_greater_or_equal("2.3") else torch.device("cpu")
|
||||
# This case means no context manager was used -> we still check if the default that was potentially set is not cpu
|
||||
if device_in_context == default_device:
|
||||
if default_device != torch.device("cpu"):
|
||||
|
||||
Reference in New Issue
Block a user