[modeling_utils] torch_dtype/auto floating dtype fixes (#17614)
* [modeling_utils] torch_dtype/auto fixes * add test * apply suggestions * add missing fallback * Renaming things * Use for else Co-authored-by: Sylvain Gugger <Sylvain.gugger@gmail.com>
This commit is contained in:
@@ -132,7 +132,10 @@ def get_parameter_device(parameter: Union[nn.Module, GenerationMixin, "ModuleUti
|
|||||||
return first_tuple[1].device
|
return first_tuple[1].device
|
||||||
|
|
||||||
|
|
||||||
def get_parameter_dtype(parameter: Union[nn.Module, GenerationMixin, "ModuleUtilsMixin"]):
|
def get_first_parameter_dtype(parameter: Union[nn.Module, GenerationMixin, "ModuleUtilsMixin"]):
|
||||||
|
"""
|
||||||
|
Returns the first parameter dtype (can be non-floating) or asserts if none were found.
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
return next(parameter.parameters()).dtype
|
return next(parameter.parameters()).dtype
|
||||||
except StopIteration:
|
except StopIteration:
|
||||||
@@ -147,6 +150,58 @@ def get_parameter_dtype(parameter: Union[nn.Module, GenerationMixin, "ModuleUtil
|
|||||||
return first_tuple[1].dtype
|
return first_tuple[1].dtype
|
||||||
|
|
||||||
|
|
||||||
|
def get_parameter_dtype(parameter: Union[nn.Module, GenerationMixin, "ModuleUtilsMixin"]):
|
||||||
|
"""
|
||||||
|
Returns the first found floating dtype in parameters if there is one, otherwise returns the last dtype it found.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
for t in parameter.parameters():
|
||||||
|
if t.is_floating_point():
|
||||||
|
return t.dtype
|
||||||
|
# if no floating dtype was found return whatever the first dtype is
|
||||||
|
else:
|
||||||
|
return t.dtype
|
||||||
|
|
||||||
|
except StopIteration:
|
||||||
|
# For nn.DataParallel compatibility in PyTorch 1.5
|
||||||
|
|
||||||
|
def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
|
||||||
|
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
|
||||||
|
return tuples
|
||||||
|
|
||||||
|
gen = parameter._named_members(get_members_fn=find_tensor_attributes)
|
||||||
|
for tuple in gen:
|
||||||
|
if tuple[1].is_floating_point():
|
||||||
|
return tuple[1].dtype
|
||||||
|
# fallback to any dtype the model has even if not floating
|
||||||
|
else:
|
||||||
|
return tuple[1].dtype
|
||||||
|
|
||||||
|
|
||||||
|
def get_state_dict_float_dtype(state_dict):
|
||||||
|
"""
|
||||||
|
Returns the first found floating dtype in `state_dict` or asserts if none were found.
|
||||||
|
"""
|
||||||
|
for t in state_dict.values():
|
||||||
|
if t.is_floating_point():
|
||||||
|
return t.dtype
|
||||||
|
|
||||||
|
raise ValueError("couldn't find any floating point dtypes in state_dict")
|
||||||
|
|
||||||
|
|
||||||
|
def get_state_dict_dtype(state_dict):
|
||||||
|
"""
|
||||||
|
Returns the first found floating dtype in `state_dict` if there is one, otherwise returns the last dtype.
|
||||||
|
"""
|
||||||
|
for t in state_dict.values():
|
||||||
|
if t.is_floating_point():
|
||||||
|
return t.dtype
|
||||||
|
|
||||||
|
# if no floating dtype was found return whatever the first dtype is
|
||||||
|
else:
|
||||||
|
return t.dtype
|
||||||
|
|
||||||
|
|
||||||
def convert_file_size_to_int(size: Union[int, str]):
|
def convert_file_size_to_int(size: Union[int, str]):
|
||||||
"""
|
"""
|
||||||
Converts a size expressed as a string with digits an unit (like `"5MB"`) to an integer (in bytes).
|
Converts a size expressed as a string with digits an unit (like `"5MB"`) to an integer (in bytes).
|
||||||
@@ -2076,7 +2131,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
# set dtype to instantiate the model under:
|
# set dtype to instantiate the model under:
|
||||||
# 1. If torch_dtype is not None, we use that dtype
|
# 1. If torch_dtype is not None, we use that dtype
|
||||||
# 2. If torch_dtype is "auto", we auto-detect dtype from the loaded state_dict, by checking its first
|
# 2. If torch_dtype is "auto", we auto-detect dtype from the loaded state_dict, by checking its first
|
||||||
# weights entry - we assume all weights are of the same dtype
|
# weights entry that is of a floating type - we assume all floating dtype weights are of the same dtype
|
||||||
# we also may have config.torch_dtype available, but we won't rely on it till v5
|
# we also may have config.torch_dtype available, but we won't rely on it till v5
|
||||||
dtype_orig = None
|
dtype_orig = None
|
||||||
if torch_dtype is not None:
|
if torch_dtype is not None:
|
||||||
@@ -2085,10 +2140,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
if is_sharded and "dtype" in sharded_metadata:
|
if is_sharded and "dtype" in sharded_metadata:
|
||||||
torch_dtype = sharded_metadata["dtype"]
|
torch_dtype = sharded_metadata["dtype"]
|
||||||
elif not is_sharded:
|
elif not is_sharded:
|
||||||
torch_dtype = next(iter(state_dict.values())).dtype
|
torch_dtype = get_state_dict_dtype(state_dict)
|
||||||
else:
|
else:
|
||||||
one_state_dict = load_state_dict(resolved_archive_file)
|
one_state_dict = load_state_dict(resolved_archive_file)
|
||||||
torch_dtype = next(iter(one_state_dict.values())).dtype
|
torch_dtype = get_state_dict_dtype(one_state_dict)
|
||||||
del one_state_dict # free CPU memory
|
del one_state_dict # free CPU memory
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
|||||||
@@ -134,6 +134,7 @@ def _config_zero_init(config):
|
|||||||
|
|
||||||
|
|
||||||
TINY_T5 = "patrickvonplaten/t5-tiny-random"
|
TINY_T5 = "patrickvonplaten/t5-tiny-random"
|
||||||
|
TINY_BERT_FOR_TOKEN_CLASSIFICATION = "hf-internal-testing/tiny-bert-for-token-classification"
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
@@ -2557,6 +2558,10 @@ class ModelUtilsTest(TestCasePlus):
|
|||||||
model = AutoModel.from_pretrained(TINY_T5, torch_dtype=torch.float16)
|
model = AutoModel.from_pretrained(TINY_T5, torch_dtype=torch.float16)
|
||||||
self.assertEqual(model.dtype, torch.float16)
|
self.assertEqual(model.dtype, torch.float16)
|
||||||
|
|
||||||
|
# test model whose first param is not of a floating type, but int
|
||||||
|
model = AutoModel.from_pretrained(TINY_BERT_FOR_TOKEN_CLASSIFICATION, torch_dtype="auto")
|
||||||
|
self.assertEqual(model.dtype, torch.float32)
|
||||||
|
|
||||||
def test_no_super_init_config_and_model(self):
|
def test_no_super_init_config_and_model(self):
|
||||||
config = NoSuperInitConfig(attribute=32)
|
config = NoSuperInitConfig(attribute=32)
|
||||||
model = NoSuperInitModel(config)
|
model = NoSuperInitModel(config)
|
||||||
|
|||||||
Reference in New Issue
Block a user