Update config.torch_dtype correctly (#36679)

* fix

* style

* new test
This commit is contained in:
Marc Sun
2025-03-13 12:08:02 +01:00
committed by GitHub
parent c4161238bd
commit fbb18ce68b
2 changed files with 11 additions and 0 deletions

View File

@@ -1223,10 +1223,12 @@ def _get_torch_dtype(
)
elif hasattr(torch, torch_dtype):
torch_dtype = getattr(torch, torch_dtype)
config.torch_dtype = torch_dtype
for sub_config_key in config.sub_configs.keys():
sub_config = getattr(config, sub_config_key)
sub_config.torch_dtype = torch_dtype
elif isinstance(torch_dtype, torch.dtype):
config.torch_dtype = torch_dtype
for sub_config_key in config.sub_configs.keys():
sub_config = getattr(config, sub_config_key)
sub_config.torch_dtype = torch_dtype