Fix _load_state_dict_into_meta_model with device_map=None (#36488)

* Fix _load_state_dict_into_meta_model with device_map=None

* Update src/transformers/modeling_utils.py
This commit is contained in:
hlky
2025-03-02 07:33:36 +00:00
committed by GitHub
parent a40f1ac602
commit dcbdf7e962

View File

@@ -785,7 +785,7 @@ def _load_state_dict_into_meta_model(
tensor_device = None
if device_map is not None and device_map.get("", None) is not None:
tensor_device = device_map[""].index if isinstance(device_map[""], torch.device) else device_map[""]
if device_map is not None:
device_map_regex = "|".join(sorted(device_map.keys(), reverse=True))
# we need this later to initialize tensor parallelism