Improve performance of load_state_dict (#37902)
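Improve performance of load_state_dict: look up each tensor's dtype (and raise on an unknown dtype) only when map_location == "meta", reusing a single f.get_slice(k) result for both the dtype and the shape; for any other map_location the loop now calls f.get_tensor(k) directly without first fetching a slice.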
This commit is contained in:
@@ -507,13 +507,14 @@ def load_state_dict(
|
|||||||
)
|
)
|
||||||
state_dict = {}
|
state_dict = {}
|
||||||
for k in f.keys():
|
for k in f.keys():
|
||||||
k_dtype = f.get_slice(k).get_dtype()
|
|
||||||
if k_dtype in str_to_torch_dtype:
|
|
||||||
dtype = str_to_torch_dtype[k_dtype]
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Cannot load safetensors of unknown dtype {k_dtype}")
|
|
||||||
if map_location == "meta":
|
if map_location == "meta":
|
||||||
state_dict[k] = torch.empty(size=f.get_slice(k).get_shape(), dtype=dtype, device="meta")
|
_slice = f.get_slice(k)
|
||||||
|
k_dtype = _slice.get_dtype()
|
||||||
|
if k_dtype in str_to_torch_dtype:
|
||||||
|
dtype = str_to_torch_dtype[k_dtype]
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Cannot load safetensors of unknown dtype {k_dtype}")
|
||||||
|
state_dict[k] = torch.empty(size=_slice.get_shape(), dtype=dtype, device="meta")
|
||||||
else:
|
else:
|
||||||
state_dict[k] = f.get_tensor(k)
|
state_dict[k] = f.get_tensor(k)
|
||||||
return state_dict
|
return state_dict
|
||||||
|
|||||||
Reference in New Issue
Block a user