[Modeling] Load FP8 safetensors such as DeepSeek (#36828)
support loading fp8 Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
@@ -532,6 +532,9 @@ str_to_torch_dtype = {
|
||||
"I64": torch.int64,
|
||||
}
|
||||
|
||||
if is_torch_greater_or_equal("2.1.0"):
|
||||
str_to_torch_dtype["F8_E4M3"] = torch.float8_e4m3fn
|
||||
|
||||
if is_torch_greater_or_equal("2.3.0"):
|
||||
str_to_torch_dtype["U16"] = torch.uint16
|
||||
str_to_torch_dtype["U32"] = torch.uint32
|
||||
@@ -562,7 +565,11 @@ def load_state_dict(
|
||||
)
|
||||
state_dict = {}
|
||||
for k in f.keys():
|
||||
dtype = str_to_torch_dtype[f.get_slice(k).get_dtype()]
|
||||
k_dtype = f.get_slice(k).get_dtype()
|
||||
if k_dtype in str_to_torch_dtype:
|
||||
dtype = str_to_torch_dtype[k_dtype]
|
||||
else:
|
||||
raise ValueError(f"Cannot load safetensors of unknown dtype {k_dtype}")
|
||||
if map_location == "meta":
|
||||
state_dict[k] = torch.empty(size=f.get_slice(k).get_shape(), dtype=dtype, device="meta")
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user