@@ -1223,10 +1223,12 @@ def _get_torch_dtype(
|
||||
)
|
||||
elif hasattr(torch, torch_dtype):
|
||||
torch_dtype = getattr(torch, torch_dtype)
|
||||
config.torch_dtype = torch_dtype
|
||||
for sub_config_key in config.sub_configs.keys():
|
||||
sub_config = getattr(config, sub_config_key)
|
||||
sub_config.torch_dtype = torch_dtype
|
||||
elif isinstance(torch_dtype, torch.dtype):
|
||||
config.torch_dtype = torch_dtype
|
||||
for sub_config_key in config.sub_configs.keys():
|
||||
sub_config = getattr(config, sub_config_key)
|
||||
sub_config.torch_dtype = torch_dtype
|
||||
|
||||
Reference in New Issue
Block a user