Improve performance of load_state_dict (#37902)

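Improve performance of load_state_dict: the safetensors dtype lookup (and its unknown-dtype ValueError) now runs only when map_location == "meta", and the slice returned by f.get_slice(k) is fetched once and reused for both get_dtype() and get_shape() instead of being requested twice per key.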
This commit is contained in:
woctordho
2025-05-01 22:35:17 +08:00
committed by GitHub
parent 410aa01901
commit ee25d57ed1

View File

@@ -507,13 +507,14 @@ def load_state_dict(
) )
state_dict = {} state_dict = {}
for k in f.keys(): for k in f.keys():
k_dtype = f.get_slice(k).get_dtype()
if k_dtype in str_to_torch_dtype:
dtype = str_to_torch_dtype[k_dtype]
else:
raise ValueError(f"Cannot load safetensors of unknown dtype {k_dtype}")
if map_location == "meta": if map_location == "meta":
state_dict[k] = torch.empty(size=f.get_slice(k).get_shape(), dtype=dtype, device="meta") _slice = f.get_slice(k)
k_dtype = _slice.get_dtype()
if k_dtype in str_to_torch_dtype:
dtype = str_to_torch_dtype[k_dtype]
else:
raise ValueError(f"Cannot load safetensors of unknown dtype {k_dtype}")
state_dict[k] = torch.empty(size=_slice.get_shape(), dtype=dtype, device="meta")
else: else:
state_dict[k] = f.get_tensor(k) state_dict[k] = f.get_tensor(k)
return state_dict return state_dict