Fix _load_state_dict_into_meta_model with device_map=None (#36488)
* Fix _load_state_dict_into_meta_model with device_map=None * Update src/transformers/modeling_utils.py
This commit is contained in:
@@ -785,7 +785,7 @@ def _load_state_dict_into_meta_model(
|
|||||||
tensor_device = None
|
tensor_device = None
|
||||||
if device_map is not None and device_map.get("", None) is not None:
|
if device_map is not None and device_map.get("", None) is not None:
|
||||||
tensor_device = device_map[""].index if isinstance(device_map[""], torch.device) else device_map[""]
|
tensor_device = device_map[""].index if isinstance(device_map[""], torch.device) else device_map[""]
|
||||||
|
if device_map is not None:
|
||||||
device_map_regex = "|".join(sorted(device_map.keys(), reverse=True))
|
device_map_regex = "|".join(sorted(device_map.keys(), reverse=True))
|
||||||
|
|
||||||
# we need this later to initialize tensor parallelism
|
# we need this later to initialize tensor parallelism
|
||||||
|
|||||||
Reference in New Issue
Block a user