[Modeling] Load FP8 safetensors such as DeepSeek (#36828)

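Support loading FP8 safetensors checkpoints: map the `F8_E4M3` safetensors dtype to `torch.float8_e4m3fn` (torch >= 2.1.0), and raise a descriptive `ValueError` when a tensor has a dtype that is not in `str_to_torch_dtype`.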

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
Kyle Sayers
2025-03-27 05:47:10 -04:00
committed by GitHub
parent 927ce1d39f
commit d6d930a64b

View File

@@ -532,6 +532,9 @@ str_to_torch_dtype = {
"I64": torch.int64,
}
if is_torch_greater_or_equal("2.1.0"):
str_to_torch_dtype["F8_E4M3"] = torch.float8_e4m3fn
if is_torch_greater_or_equal("2.3.0"):
str_to_torch_dtype["U16"] = torch.uint16
str_to_torch_dtype["U32"] = torch.uint32
@@ -562,7 +565,11 @@ def load_state_dict(
)
state_dict = {}
for k in f.keys():
dtype = str_to_torch_dtype[f.get_slice(k).get_dtype()]
k_dtype = f.get_slice(k).get_dtype()
if k_dtype in str_to_torch_dtype:
dtype = str_to_torch_dtype[k_dtype]
else:
raise ValueError(f"Cannot load safetensors of unknown dtype {k_dtype}")
if map_location == "meta":
state_dict[k] = torch.empty(size=f.get_slice(k).get_shape(), dtype=dtype, device="meta")
else: