No more dtype_byte_size() (#37144)
* No more dtype_byte_size() * Remove function once again * Fix rebase cruft * Trigger tests
This commit is contained in:
@@ -116,7 +116,6 @@ if is_torch_available():
|
||||
from transformers.modeling_utils import (
|
||||
_find_disjoint,
|
||||
_find_identical,
|
||||
dtype_byte_size,
|
||||
)
|
||||
from transformers.pytorch_utils import isin_mps_friendly
|
||||
|
||||
@@ -704,31 +703,6 @@ class ModelUtilsTest(TestCasePlus):
|
||||
model = AutoModelForCausalLM.from_config(config=config, attn_implementation=requested_attn_implementation)
|
||||
self.assertEqual(model.config._attn_implementation, requested_attn_implementation)
|
||||
|
||||
def test_torch_dtype_byte_sizes(self):
|
||||
torch_dtypes_and_bytes = [
|
||||
(torch.double, 8),
|
||||
(torch.float64, 8),
|
||||
(torch.float, 4),
|
||||
(torch.float32, 4),
|
||||
(torch.half, 2),
|
||||
(torch.float16, 2),
|
||||
(torch.bfloat16, 2),
|
||||
(torch.long, 8),
|
||||
(torch.int64, 8),
|
||||
(torch.int, 4),
|
||||
(torch.int32, 4),
|
||||
(torch.short, 2),
|
||||
(torch.int16, 2),
|
||||
(torch.uint8, 1),
|
||||
(torch.int8, 1),
|
||||
(torch.float8_e4m3fn, 1),
|
||||
(torch.float8_e5m2, 1),
|
||||
(torch.bool, 0.125),
|
||||
]
|
||||
|
||||
for torch_dtype, bytes_per_element in torch_dtypes_and_bytes:
|
||||
self.assertEqual(dtype_byte_size(torch_dtype), bytes_per_element)
|
||||
|
||||
def test_no_super_init_config_and_model(self):
|
||||
config = NoSuperInitConfig(attribute=32)
|
||||
model = NoSuperInitModel(config)
|
||||
|
||||
Reference in New Issue
Block a user