Fix: dtype cannot be str (#36262)
* fix * this wasn't supposed to be here, revert * refine tests a bit more
This commit is contained in:
committed by
GitHub
parent
3f9ff19b4e
commit
523f6e743c
@@ -482,9 +482,11 @@ class ModelUtilsTest(TestCasePlus):
|
||||
# test that from_pretrained works with torch_dtype being strings like "float32" for PyTorch backend
|
||||
model = AutoModel.from_pretrained(TINY_T5, torch_dtype="float32")
|
||||
self.assertEqual(model.dtype, torch.float32)
|
||||
self.assertIsInstance(model.config.torch_dtype, torch.dtype)
|
||||
|
||||
model = AutoModel.from_pretrained(TINY_T5, torch_dtype="float16")
|
||||
self.assertEqual(model.dtype, torch.float16)
|
||||
self.assertIsInstance(model.config.torch_dtype, torch.dtype)
|
||||
|
||||
# torch.set_default_dtype() supports only float dtypes, so will fail with non-float type
|
||||
with self.assertRaises(ValueError):
|
||||
@@ -495,14 +497,22 @@ class ModelUtilsTest(TestCasePlus):
|
||||
Test that from_pretrained works with torch_dtype passed as a dict specifying a dtype for each sub-config in a composite config
|
||||
Tiny-Llava has saved auto dtype as `torch.float32` for all modules.
|
||||
"""
|
||||
# Load without dtype specified
|
||||
model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA)
|
||||
self.assertEqual(model.language_model.dtype, torch.float32)
|
||||
self.assertEqual(model.vision_tower.dtype, torch.float32)
|
||||
self.assertIsInstance(model.config.torch_dtype, torch.dtype)
|
||||
|
||||
# should be able to set torch_dtype as a simple string and the model loads it correctly
|
||||
model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, torch_dtype="float32")
|
||||
self.assertEqual(model.language_model.dtype, torch.float32)
|
||||
self.assertEqual(model.vision_tower.dtype, torch.float32)
|
||||
self.assertIsInstance(model.config.torch_dtype, torch.dtype)
|
||||
|
||||
model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, torch_dtype=torch.float16)
|
||||
self.assertEqual(model.language_model.dtype, torch.float16)
|
||||
self.assertEqual(model.vision_tower.dtype, torch.float16)
|
||||
self.assertIsInstance(model.config.torch_dtype, torch.dtype)
|
||||
|
||||
# should be able to set torch_dtype as a dict for each sub-config
|
||||
model = LlavaForConditionalGeneration.from_pretrained(
|
||||
@@ -511,6 +521,7 @@ class ModelUtilsTest(TestCasePlus):
|
||||
self.assertEqual(model.language_model.dtype, torch.float32)
|
||||
self.assertEqual(model.vision_tower.dtype, torch.float16)
|
||||
self.assertEqual(model.multi_modal_projector.linear_1.weight.dtype, torch.bfloat16)
|
||||
self.assertIsInstance(model.config.torch_dtype, torch.dtype)
|
||||
|
||||
# should be able to set the values as torch.dtype (not str)
|
||||
model = LlavaForConditionalGeneration.from_pretrained(
|
||||
@@ -519,6 +530,7 @@ class ModelUtilsTest(TestCasePlus):
|
||||
self.assertEqual(model.language_model.dtype, torch.float32)
|
||||
self.assertEqual(model.vision_tower.dtype, torch.float16)
|
||||
self.assertEqual(model.multi_modal_projector.linear_1.weight.dtype, torch.bfloat16)
|
||||
self.assertIsInstance(model.config.torch_dtype, torch.dtype)
|
||||
|
||||
# should be able to set the values in configs directly and pass it to `from_pretrained`
|
||||
config = copy.deepcopy(model.config)
|
||||
@@ -529,6 +541,7 @@ class ModelUtilsTest(TestCasePlus):
|
||||
self.assertEqual(model.language_model.dtype, torch.float32)
|
||||
self.assertEqual(model.vision_tower.dtype, torch.bfloat16)
|
||||
self.assertEqual(model.multi_modal_projector.linear_1.weight.dtype, torch.float16)
|
||||
self.assertIsInstance(model.config.torch_dtype, torch.dtype)
|
||||
|
||||
# but if the model has `_keep_in_fp32_modules` then those modules should be in fp32 no matter what
|
||||
LlavaForConditionalGeneration._keep_in_fp32_modules = ["multi_modal_projector"]
|
||||
@@ -536,6 +549,7 @@ class ModelUtilsTest(TestCasePlus):
|
||||
self.assertEqual(model.language_model.dtype, torch.float32)
|
||||
self.assertEqual(model.vision_tower.dtype, torch.bfloat16)
|
||||
self.assertEqual(model.multi_modal_projector.linear_1.weight.dtype, torch.float32)
|
||||
self.assertIsInstance(model.config.torch_dtype, torch.dtype)
|
||||
|
||||
# torch.set_default_dtype() supports only float dtypes, so will fail with non-float type
|
||||
with self.assertRaises(ValueError):
|
||||
|
||||
Reference in New Issue
Block a user