Fix _load_state_dict_into_meta_model with device_map=None (#36488)
* Fix _load_state_dict_into_meta_model with device_map=None * Update src/transformers/modeling_utils.py
This commit is contained in:
@@ -785,7 +785,7 @@ def _load_state_dict_into_meta_model(
|
||||
tensor_device = None
|
||||
if device_map is not None and device_map.get("", None) is not None:
|
||||
tensor_device = device_map[""].index if isinstance(device_map[""], torch.device) else device_map[""]
|
||||
|
||||
if device_map is not None:
|
||||
device_map_regex = "|".join(sorted(device_map.keys(), reverse=True))
|
||||
|
||||
# we need this later to initialize tensor parallelism
|
||||
|
||||
Reference in New Issue
Block a user