Update dtype_byte_size to handle torch.float8_e4m3fn/float8_e5m2 types (#30488)
* Update modeling_utils/dtype_byte_size to handle float8 types * Add a test for dtype_byte_size * Format * Fix bool
This commit is contained in:
@@ -101,7 +101,12 @@ if is_torch_available():
|
||||
_prepare_4d_attention_mask,
|
||||
_prepare_4d_causal_attention_mask,
|
||||
)
|
||||
from transformers.modeling_utils import _find_disjoint, _find_identical, shard_checkpoint
|
||||
from transformers.modeling_utils import (
|
||||
_find_disjoint,
|
||||
_find_identical,
|
||||
dtype_byte_size,
|
||||
shard_checkpoint,
|
||||
)
|
||||
|
||||
# Fake pretrained models for tests
|
||||
class BaseModel(PreTrainedModel):
|
||||
@@ -465,6 +470,31 @@ class ModelUtilsTest(TestCasePlus):
|
||||
module.__class__.__name__, mistral_attention_classes[requested_attn_implementation]
|
||||
)
|
||||
|
||||
def test_torch_dtype_byte_sizes(self):
|
||||
torch_dtypes_and_bytes = [
|
||||
(torch.double, 8),
|
||||
(torch.float64, 8),
|
||||
(torch.float, 4),
|
||||
(torch.float32, 4),
|
||||
(torch.half, 2),
|
||||
(torch.float16, 2),
|
||||
(torch.bfloat16, 2),
|
||||
(torch.long, 8),
|
||||
(torch.int64, 8),
|
||||
(torch.int, 4),
|
||||
(torch.int32, 4),
|
||||
(torch.short, 2),
|
||||
(torch.int16, 2),
|
||||
(torch.uint8, 1),
|
||||
(torch.int8, 1),
|
||||
(torch.float8_e4m3fn, 1),
|
||||
(torch.float8_e5m2, 1),
|
||||
(torch.bool, 0.125),
|
||||
]
|
||||
|
||||
for torch_dtype, bytes_per_element in torch_dtypes_and_bytes:
|
||||
self.assertEqual(dtype_byte_size(torch_dtype), bytes_per_element)
|
||||
|
||||
def test_no_super_init_config_and_model(self):
|
||||
config = NoSuperInitConfig(attribute=32)
|
||||
model = NoSuperInitModel(config)
|
||||
|
||||
Reference in New Issue
Block a user