From 752ef3fd4e70869626ec70657a770a85c0ad9219 Mon Sep 17 00:00:00 2001 From: Marc Sun <57196510+SunMarc@users.noreply.github.com> Date: Wed, 5 Mar 2025 11:27:01 +0100 Subject: [PATCH] guard torch version for uint16 (#36520) * u16 * style * fix --- src/transformers/modeling_utils.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 9cb3de74c2..763c8e6b6e 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -522,17 +522,19 @@ str_to_torch_dtype = { "U8": torch.uint8, "I8": torch.int8, "I16": torch.int16, - "U16": torch.uint16, "F16": torch.float16, "BF16": torch.bfloat16, "I32": torch.int32, - "U32": torch.uint32, "F32": torch.float32, "F64": torch.float64, "I64": torch.int64, - "U64": torch.uint64, } +if is_torch_greater_or_equal("2.3.0"): + str_to_torch_dtype["U16"] = torch.uint16 + str_to_torch_dtype["U32"] = torch.uint32 + str_to_torch_dtype["U64"] = torch.uint64 + def load_state_dict( checkpoint_file: Union[str, os.PathLike],