Fix: dtype cannot be str (#36262)
* fix * this wan't supposed to be here, revert * refine tests a bit more
This commit is contained in:
committed by
GitHub
parent
3f9ff19b4e
commit
523f6e743c
@@ -1252,13 +1252,13 @@ def _get_torch_dtype(
|
||||
for key, curr_dtype in torch_dtype.items():
|
||||
if hasattr(config, key):
|
||||
value = getattr(config, key)
|
||||
curr_dtype = curr_dtype if not isinstance(curr_dtype, str) else getattr(torch, curr_dtype)
|
||||
value.torch_dtype = curr_dtype
|
||||
# main torch dtype for modules that aren't part of any sub-config
|
||||
torch_dtype = torch_dtype.get("")
|
||||
torch_dtype = torch_dtype if not isinstance(torch_dtype, str) else getattr(torch, torch_dtype)
|
||||
config.torch_dtype = torch_dtype
|
||||
if isinstance(torch_dtype, str) and hasattr(torch, torch_dtype):
|
||||
torch_dtype = getattr(torch, torch_dtype)
|
||||
elif torch_dtype is None:
|
||||
if torch_dtype is None:
|
||||
torch_dtype = torch.float32
|
||||
else:
|
||||
raise ValueError(
|
||||
@@ -1269,7 +1269,7 @@ def _get_torch_dtype(
|
||||
dtype_orig = cls._set_default_torch_dtype(torch_dtype)
|
||||
else:
|
||||
# set fp32 as the default dtype for BC
|
||||
default_dtype = str(torch.get_default_dtype()).split(".")[-1]
|
||||
default_dtype = torch.get_default_dtype()
|
||||
config.torch_dtype = default_dtype
|
||||
for key in config.sub_configs.keys():
|
||||
value = getattr(config, key)
|
||||
|
||||
Reference in New Issue
Block a user