A bit of cleaning 🧹🧹 (#37215)

* cleaning

* CIs
This commit is contained in:
Cyril Vallez
2025-04-08 14:33:58 +02:00
committed by GitHub
parent 1e6b546ea6
commit cdfb018d03
2 changed files with 0 additions and 101 deletions

View File

@@ -298,24 +298,6 @@ def get_parameter_device(parameter: Union[nn.Module, "ModuleUtilsMixin"]):
return first_tuple[1].device
def get_first_parameter_dtype(parameter: Union[nn.Module, "ModuleUtilsMixin"]):
"""
Returns the first parameter dtype (can be non-floating) or asserts if none were found.
"""
try:
return next(parameter.parameters()).dtype
except StopIteration:
# For nn.DataParallel compatibility in PyTorch > 1.5
def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
return tuples
gen = parameter._named_members(get_members_fn=find_tensor_attributes)
first_tuple = next(gen)
return first_tuple[1].dtype
def get_parameter_dtype(parameter: Union[nn.Module, "ModuleUtilsMixin"]):
"""
Returns the first found floating dtype in parameters if there is one, otherwise returns the last dtype it found.
@@ -365,17 +347,6 @@ def get_parameter_dtype(parameter: Union[nn.Module, "ModuleUtilsMixin"]):
return last_dtype
def get_state_dict_float_dtype(state_dict):
"""
Returns the first found floating dtype in `state_dict` or asserts if none were found.
"""
for t in state_dict.values():
if t.is_floating_point():
return t.dtype
raise ValueError("couldn't find any floating point dtypes in state_dict")
def get_state_dict_dtype(state_dict):
"""
Returns the first found floating dtype in `state_dict` if there is one, otherwise returns the first dtype.