From 20081c743ee2ce31d178f2182c7466c3313adcd2 Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Fri, 26 Apr 2024 06:26:43 -0400 Subject: [PATCH] Update `dtype_byte_size` to handle torch.float8_e4m3fn/float8_e5m2 types (#30488) * Update modeling_utils/dtype_byte_size to handle float8 types * Add a test for dtype_byte_size * Format * Fix bool --- src/transformers/modeling_utils.py | 2 +- tests/test_modeling_utils.py | 32 +++++++++++++++++++++++++++++- 2 files changed, 32 insertions(+), 2 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index a004375b5e..1ed8040f88 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -324,7 +324,7 @@ def dtype_byte_size(dtype): """ if dtype == torch.bool: return 1 / 8 - bit_search = re.search(r"[^\d](\d+)$", str(dtype)) + bit_search = re.search(r"[^\d](\d+)_?", str(dtype)) if bit_search is None: raise ValueError(f"`dtype` is not a valid dtype: {dtype}.") bit_size = int(bit_search.groups()[0]) diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index ba0bf8e6b2..16d8e9e129 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -101,7 +101,12 @@ if is_torch_available(): _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask, ) - from transformers.modeling_utils import _find_disjoint, _find_identical, shard_checkpoint + from transformers.modeling_utils import ( + _find_disjoint, + _find_identical, + dtype_byte_size, + shard_checkpoint, + ) # Fake pretrained models for tests class BaseModel(PreTrainedModel): @@ -465,6 +470,31 @@ class ModelUtilsTest(TestCasePlus): module.__class__.__name__, mistral_attention_classes[requested_attn_implementation] ) + def test_torch_dtype_byte_sizes(self): + torch_dtypes_and_bytes = [ + (torch.double, 8), + (torch.float64, 8), + (torch.float, 4), + (torch.float32, 4), + (torch.half, 2), + (torch.float16, 2), + (torch.bfloat16, 2), + (torch.long, 8), + (torch.int64, 8), + (torch.int, 4), + (torch.int32, 4), + (torch.short, 2), + (torch.int16, 2), + (torch.uint8, 1), + (torch.int8, 1), + (torch.float8_e4m3fn, 1), + (torch.float8_e5m2, 1), + (torch.bool, 0.125), + ] + + for torch_dtype, bytes_per_element in torch_dtypes_and_bytes: + self.assertEqual(dtype_byte_size(torch_dtype), bytes_per_element) + def test_no_super_init_config_and_model(self): config = NoSuperInitConfig(attribute=32) model = NoSuperInitModel(config)