From d6d930a64b4889d715d3989691209b3f70c11b20 Mon Sep 17 00:00:00 2001
From: Kyle Sayers
Date: Thu, 27 Mar 2025 05:47:10 -0400
Subject: [PATCH] [Modeling] Load FP8 safetensors such as DeepSeek (#36828)

support loading fp8

Signed-off-by: Kyle Sayers
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
---
 src/transformers/modeling_utils.py | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 1f7aabc6bc..abf8dec55c 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -532,6 +532,9 @@ str_to_torch_dtype = {
     "I64": torch.int64,
 }
 
+if is_torch_greater_or_equal("2.1.0"):
+    str_to_torch_dtype["F8_E4M3"] = torch.float8_e4m3fn
+
 if is_torch_greater_or_equal("2.3.0"):
     str_to_torch_dtype["U16"] = torch.uint16
     str_to_torch_dtype["U32"] = torch.uint32
@@ -562,7 +565,11 @@ def load_state_dict(
                 )
             state_dict = {}
             for k in f.keys():
-                dtype = str_to_torch_dtype[f.get_slice(k).get_dtype()]
+                k_dtype = f.get_slice(k).get_dtype()
+                if k_dtype in str_to_torch_dtype:
+                    dtype = str_to_torch_dtype[k_dtype]
+                else:
+                    raise ValueError(f"Cannot load safetensors of unknown dtype {k_dtype}")
                 if map_location == "meta":
                     state_dict[k] = torch.empty(size=f.get_slice(k).get_shape(), dtype=dtype, device="meta")
                 else: