Fix: dtype cannot be str (#36262)

* fix

* this wasn't supposed to be here, revert

* refine tests a bit more
This commit is contained in:
Raushan Turganbay
2025-03-21 13:27:47 +01:00
committed by GitHub
parent 3f9ff19b4e
commit 523f6e743c
2 changed files with 18 additions and 4 deletions

View File

@@ -1252,13 +1252,13 @@ def _get_torch_dtype(
for key, curr_dtype in torch_dtype.items(): for key, curr_dtype in torch_dtype.items():
if hasattr(config, key): if hasattr(config, key):
value = getattr(config, key) value = getattr(config, key)
curr_dtype = curr_dtype if not isinstance(curr_dtype, str) else getattr(torch, curr_dtype)
value.torch_dtype = curr_dtype value.torch_dtype = curr_dtype
# main torch dtype for modules that aren't part of any sub-config # main torch dtype for modules that aren't part of any sub-config
torch_dtype = torch_dtype.get("") torch_dtype = torch_dtype.get("")
torch_dtype = torch_dtype if not isinstance(torch_dtype, str) else getattr(torch, torch_dtype)
config.torch_dtype = torch_dtype config.torch_dtype = torch_dtype
if isinstance(torch_dtype, str) and hasattr(torch, torch_dtype): if torch_dtype is None:
torch_dtype = getattr(torch, torch_dtype)
elif torch_dtype is None:
torch_dtype = torch.float32 torch_dtype = torch.float32
else: else:
raise ValueError( raise ValueError(
@@ -1269,7 +1269,7 @@ def _get_torch_dtype(
dtype_orig = cls._set_default_torch_dtype(torch_dtype) dtype_orig = cls._set_default_torch_dtype(torch_dtype)
else: else:
# set fp32 as the default dtype for BC # set fp32 as the default dtype for BC
default_dtype = str(torch.get_default_dtype()).split(".")[-1] default_dtype = torch.get_default_dtype()
config.torch_dtype = default_dtype config.torch_dtype = default_dtype
for key in config.sub_configs.keys(): for key in config.sub_configs.keys():
value = getattr(config, key) value = getattr(config, key)

View File

@@ -482,9 +482,11 @@ class ModelUtilsTest(TestCasePlus):
# test that from_pretrained works with torch_dtype being strings like "float32" for PyTorch backend # test that from_pretrained works with torch_dtype being strings like "float32" for PyTorch backend
model = AutoModel.from_pretrained(TINY_T5, torch_dtype="float32") model = AutoModel.from_pretrained(TINY_T5, torch_dtype="float32")
self.assertEqual(model.dtype, torch.float32) self.assertEqual(model.dtype, torch.float32)
self.assertIsInstance(model.config.torch_dtype, torch.dtype)
model = AutoModel.from_pretrained(TINY_T5, torch_dtype="float16") model = AutoModel.from_pretrained(TINY_T5, torch_dtype="float16")
self.assertEqual(model.dtype, torch.float16) self.assertEqual(model.dtype, torch.float16)
self.assertIsInstance(model.config.torch_dtype, torch.dtype)
# torch.set_default_dtype() supports only float dtypes, so will fail with non-float type # torch.set_default_dtype() supports only float dtypes, so will fail with non-float type
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
@@ -495,14 +497,22 @@ class ModelUtilsTest(TestCasePlus):
Test that from_pretrained works with torch_dtype being as a dict per each sub-config in composite config Test that from_pretrained works with torch_dtype being as a dict per each sub-config in composite config
Tiny-Llava has saved auto dtype as `torch.float32` for all modules. Tiny-Llava has saved auto dtype as `torch.float32` for all modules.
""" """
# Load without dtype specified
model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA)
self.assertEqual(model.language_model.dtype, torch.float32)
self.assertEqual(model.vision_tower.dtype, torch.float32)
self.assertIsInstance(model.config.torch_dtype, torch.dtype)
# should be able to set torch_dtype as a simple string and the model loads it correctly # should be able to set torch_dtype as a simple string and the model loads it correctly
model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, torch_dtype="float32") model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, torch_dtype="float32")
self.assertEqual(model.language_model.dtype, torch.float32) self.assertEqual(model.language_model.dtype, torch.float32)
self.assertEqual(model.vision_tower.dtype, torch.float32) self.assertEqual(model.vision_tower.dtype, torch.float32)
self.assertIsInstance(model.config.torch_dtype, torch.dtype)
model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, torch_dtype=torch.float16) model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, torch_dtype=torch.float16)
self.assertEqual(model.language_model.dtype, torch.float16) self.assertEqual(model.language_model.dtype, torch.float16)
self.assertEqual(model.vision_tower.dtype, torch.float16) self.assertEqual(model.vision_tower.dtype, torch.float16)
self.assertIsInstance(model.config.torch_dtype, torch.dtype)
# should be able to set torch_dtype as a dict for each sub-config # should be able to set torch_dtype as a dict for each sub-config
model = LlavaForConditionalGeneration.from_pretrained( model = LlavaForConditionalGeneration.from_pretrained(
@@ -511,6 +521,7 @@ class ModelUtilsTest(TestCasePlus):
self.assertEqual(model.language_model.dtype, torch.float32) self.assertEqual(model.language_model.dtype, torch.float32)
self.assertEqual(model.vision_tower.dtype, torch.float16) self.assertEqual(model.vision_tower.dtype, torch.float16)
self.assertEqual(model.multi_modal_projector.linear_1.weight.dtype, torch.bfloat16) self.assertEqual(model.multi_modal_projector.linear_1.weight.dtype, torch.bfloat16)
self.assertIsInstance(model.config.torch_dtype, torch.dtype)
# should be able to set the values as torch.dtype (not str) # should be able to set the values as torch.dtype (not str)
model = LlavaForConditionalGeneration.from_pretrained( model = LlavaForConditionalGeneration.from_pretrained(
@@ -519,6 +530,7 @@ class ModelUtilsTest(TestCasePlus):
self.assertEqual(model.language_model.dtype, torch.float32) self.assertEqual(model.language_model.dtype, torch.float32)
self.assertEqual(model.vision_tower.dtype, torch.float16) self.assertEqual(model.vision_tower.dtype, torch.float16)
self.assertEqual(model.multi_modal_projector.linear_1.weight.dtype, torch.bfloat16) self.assertEqual(model.multi_modal_projector.linear_1.weight.dtype, torch.bfloat16)
self.assertIsInstance(model.config.torch_dtype, torch.dtype)
# should be able to set the values in configs directly and pass it to `from_pretrained` # should be able to set the values in configs directly and pass it to `from_pretrained`
config = copy.deepcopy(model.config) config = copy.deepcopy(model.config)
@@ -529,6 +541,7 @@ class ModelUtilsTest(TestCasePlus):
self.assertEqual(model.language_model.dtype, torch.float32) self.assertEqual(model.language_model.dtype, torch.float32)
self.assertEqual(model.vision_tower.dtype, torch.bfloat16) self.assertEqual(model.vision_tower.dtype, torch.bfloat16)
self.assertEqual(model.multi_modal_projector.linear_1.weight.dtype, torch.float16) self.assertEqual(model.multi_modal_projector.linear_1.weight.dtype, torch.float16)
self.assertIsInstance(model.config.torch_dtype, torch.dtype)
# but if the model has `_keep_in_fp32_modules` then those modules should be in fp32 no matter what # but if the model has `_keep_in_fp32_modules` then those modules should be in fp32 no matter what
LlavaForConditionalGeneration._keep_in_fp32_modules = ["multi_modal_projector"] LlavaForConditionalGeneration._keep_in_fp32_modules = ["multi_modal_projector"]
@@ -536,6 +549,7 @@ class ModelUtilsTest(TestCasePlus):
self.assertEqual(model.language_model.dtype, torch.float32) self.assertEqual(model.language_model.dtype, torch.float32)
self.assertEqual(model.vision_tower.dtype, torch.bfloat16) self.assertEqual(model.vision_tower.dtype, torch.bfloat16)
self.assertEqual(model.multi_modal_projector.linear_1.weight.dtype, torch.float32) self.assertEqual(model.multi_modal_projector.linear_1.weight.dtype, torch.float32)
self.assertIsInstance(model.config.torch_dtype, torch.dtype)
# torch.set_default_dtype() supports only float dtypes, so will fail with non-float type # torch.set_default_dtype() supports only float dtypes, so will fail with non-float type
with self.assertRaises(ValueError): with self.assertRaises(ValueError):