No more dtype_byte_size() (#37144)
* No more dtype_byte_size() * Remove function once again * Fix rebase cruft * Trigger tests
This commit is contained in:
@@ -385,26 +385,6 @@ def get_state_dict_dtype(state_dict):
|
||||
return next(state_dict.values()).dtype
|
||||
|
||||
|
||||
def dtype_byte_size(dtype):
|
||||
"""
|
||||
Returns the size (in bytes) occupied by one parameter of type `dtype`.
|
||||
|
||||
Example:
|
||||
|
||||
```py
|
||||
>>> dtype_byte_size(torch.float32)
|
||||
4
|
||||
```
|
||||
"""
|
||||
if dtype == torch.bool:
|
||||
return 1 / 8
|
||||
bit_search = re.search(r"[^\d](\d+)_?", str(dtype))
|
||||
if bit_search is None:
|
||||
raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
|
||||
bit_size = int(bit_search.groups()[0])
|
||||
return bit_size // 8
|
||||
|
||||
|
||||
def load_sharded_checkpoint(model, folder, strict=True, prefer_safe=True):
|
||||
"""
|
||||
This is the same as
|
||||
@@ -5820,7 +5800,7 @@ def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: Dict,
|
||||
for param_name, device in accelerator_device_map.items():
|
||||
param = model.get_parameter_or_buffer(param_name)
|
||||
# The dtype of different parameters may be different with composite models or `keep_in_fp32_modules`
|
||||
param_byte_count = math.prod(param.shape) * dtype_byte_size(param.dtype)
|
||||
param_byte_count = math.prod(param.shape) * param.element_size()
|
||||
|
||||
if tp_plan_regex is not None:
|
||||
generic_name = re.sub(r"\.\d+\.", ".*.", param_name)
|
||||
|
||||
Reference in New Issue
Block a user