From b5b76b55618aaab98064f615d9bb0e7c303dee5c Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Mon, 26 May 2025 15:00:09 +0200 Subject: [PATCH] Protect `get_default_device` for torch<2.3 (#38376) * Update modeling_utils.py * CIs --- src/transformers/modeling_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index f426d10ccd..5c377acaba 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -324,7 +324,8 @@ def get_torch_context_manager_or_global_device(): is not "cpu". This is used to infer the correct device to load the model on, in case `device_map` is not provided. """ device_in_context = torch.tensor([]).device - default_device = torch.get_default_device() + # `get_default_device` was only introduced in torch>=2.3 - use cpu otherwise to align the behavior + default_device = torch.get_default_device() if is_torch_greater_or_equal("2.3") else torch.device("cpu") # This case means no context manager was used -> we still check if the default that was potentially set is not cpu if device_in_context == default_device: if default_device != torch.device("cpu"):