Update dtype_byte_size to handle torch.float8_e4m3fn/float8_e5m2 types (#30488)
* Update modeling_utils/dtype_byte_size to handle float8 types
* Add a test for dtype_byte_size
* Format
* Fix bool
This commit is contained in:
@@ -324,7 +324,7 @@ def dtype_byte_size(dtype):
|
|||||||
"""
|
"""
|
||||||
if dtype == torch.bool:
|
if dtype == torch.bool:
|
||||||
return 1 / 8
|
return 1 / 8
|
||||||
bit_search = re.search(r"[^\d](\d+)$", str(dtype))
|
bit_search = re.search(r"[^\d](\d+)_?", str(dtype))
|
||||||
if bit_search is None:
|
if bit_search is None:
|
||||||
raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
|
raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
|
||||||
bit_size = int(bit_search.groups()[0])
|
bit_size = int(bit_search.groups()[0])
|
||||||
|
|||||||
@@ -101,7 +101,12 @@ if is_torch_available():
|
|||||||
_prepare_4d_attention_mask,
|
_prepare_4d_attention_mask,
|
||||||
_prepare_4d_causal_attention_mask,
|
_prepare_4d_causal_attention_mask,
|
||||||
)
|
)
|
||||||
from transformers.modeling_utils import _find_disjoint, _find_identical, shard_checkpoint
|
from transformers.modeling_utils import (
|
||||||
|
_find_disjoint,
|
||||||
|
_find_identical,
|
||||||
|
dtype_byte_size,
|
||||||
|
shard_checkpoint,
|
||||||
|
)
|
||||||
|
|
||||||
# Fake pretrained models for tests
|
# Fake pretrained models for tests
|
||||||
class BaseModel(PreTrainedModel):
|
class BaseModel(PreTrainedModel):
|
||||||
@@ -465,6 +470,31 @@ class ModelUtilsTest(TestCasePlus):
|
|||||||
module.__class__.__name__, mistral_attention_classes[requested_attn_implementation]
|
module.__class__.__name__, mistral_attention_classes[requested_attn_implementation]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_torch_dtype_byte_sizes(self):
    """Check `dtype_byte_size` against the expected bytes per element of torch dtypes.

    Covers floating-point dtypes (including the float8 variants), integer
    dtypes, and `torch.bool`, which is expected to report 1/8 of a byte.
    """
    # Each entry pairs a dtype with the size `dtype_byte_size` should return for it.
    expected_sizes = (
        (torch.double, 8),
        (torch.float64, 8),
        (torch.float, 4),
        (torch.float32, 4),
        (torch.half, 2),
        (torch.float16, 2),
        (torch.bfloat16, 2),
        (torch.long, 8),
        (torch.int64, 8),
        (torch.int, 4),
        (torch.int32, 4),
        (torch.short, 2),
        (torch.int16, 2),
        (torch.uint8, 1),
        (torch.int8, 1),
        (torch.float8_e4m3fn, 1),
        (torch.float8_e5m2, 1),
        (torch.bool, 0.125),
    )

    for dtype, expected in expected_sizes:
        self.assertEqual(dtype_byte_size(dtype), expected)
|
||||||
|
|
||||||
def test_no_super_init_config_and_model(self):
|
def test_no_super_init_config_and_model(self):
|
||||||
config = NoSuperInitConfig(attribute=32)
|
config = NoSuperInitConfig(attribute=32)
|
||||||
model = NoSuperInitModel(config)
|
model = NoSuperInitModel(config)
|
||||||
|
|||||||
Reference in New Issue
Block a user