Enable different torch dtype in sub models (#34873)
* fix * fix test * add tests * add more tests * fix tests * supposed to be a torch.dtype test * handle BC and make fp32 default
This commit is contained in:
committed by
GitHub
parent
87089176d9
commit
84a6789145
@@ -37,6 +37,7 @@ from transformers import (
|
||||
AutoModel,
|
||||
AutoModelForImageClassification,
|
||||
AutoModelForSequenceClassification,
|
||||
LlavaForConditionalGeneration,
|
||||
OwlViTForObjectDetection,
|
||||
PretrainedConfig,
|
||||
is_torch_available,
|
||||
@@ -300,6 +301,7 @@ TINY_T5 = "patrickvonplaten/t5-tiny-random"
|
||||
TINY_BERT_FOR_TOKEN_CLASSIFICATION = "hf-internal-testing/tiny-bert-for-token-classification"
|
||||
TINY_MISTRAL = "hf-internal-testing/tiny-random-MistralForCausalLM"
|
||||
TINY_IMAGE_CLASSIF = "hf-internal-testing/tiny-random-SiglipForImageClassification"
|
||||
# Checkpoint id of a tiny random Llava model; loaded by the composite `torch_dtype` test in this file.
TINY_LLAVA = "hf-internal-testing/tiny-random-LlavaForConditionalGeneration"
|
||||
|
||||
LOG = logging.get_logger(__name__)
|
||||
|
||||
@@ -460,6 +462,59 @@ class ModelUtilsTest(TestCasePlus):
|
||||
with self.assertRaises(ValueError):
|
||||
model = AutoModel.from_pretrained(TINY_T5, torch_dtype="int64")
|
||||
|
||||
def test_model_from_config_torch_dtype_composite(self):
    """
    Test that `from_pretrained` accepts `torch_dtype` for a composite config, either as a single
    value applied to every sub-model or as a dict with one entry per sub-config.

    The dict form uses the sub-config attribute names (`"text_config"`, `"vision_config"`) as keys
    and the empty string `""` for modules that do not belong to any sub-config.
    """
    # should be able to set torch_dtype as a simple string and the model loads it correctly
    model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, torch_dtype="float32")
    self.assertEqual(model.language_model.dtype, torch.float32)
    self.assertEqual(model.vision_tower.dtype, torch.float32)

    model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, torch_dtype="float16")
    self.assertEqual(model.language_model.dtype, torch.float16)
    self.assertEqual(model.vision_tower.dtype, torch.float16)

    # should be able to set torch_dtype as a dict for each sub-config
    model = LlavaForConditionalGeneration.from_pretrained(
        TINY_LLAVA, torch_dtype={"text_config": "float32", "vision_config": "float16", "": "bfloat16"}
    )
    self.assertEqual(model.language_model.dtype, torch.float32)
    self.assertEqual(model.vision_tower.dtype, torch.float16)
    self.assertEqual(model.multi_modal_projector.linear_1.weight.dtype, torch.bfloat16)

    # should be able to set the values as torch.dtype (not str)
    model = LlavaForConditionalGeneration.from_pretrained(
        TINY_LLAVA, torch_dtype={"text_config": torch.float32, "vision_config": torch.float16, "": torch.bfloat16}
    )
    self.assertEqual(model.language_model.dtype, torch.float32)
    self.assertEqual(model.vision_tower.dtype, torch.float16)
    self.assertEqual(model.multi_modal_projector.linear_1.weight.dtype, torch.bfloat16)

    # should be able to set the values in configs directly and pass it to `from_pretrained`
    config = copy.deepcopy(model.config)
    config.text_config.torch_dtype = torch.float32
    config.vision_config.torch_dtype = torch.bfloat16
    config.torch_dtype = torch.float16
    model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, config=config, torch_dtype="auto")
    self.assertEqual(model.language_model.dtype, torch.float32)
    self.assertEqual(model.vision_tower.dtype, torch.bfloat16)
    self.assertEqual(model.multi_modal_projector.linear_1.weight.dtype, torch.float16)

    # but if the model has `_keep_in_fp32_modules` then those modules should be in fp32 no matter what.
    # The attribute is set on the class itself, so register a cleanup that undoes the override when
    # the test finishes (pass or fail); otherwise it would leak into every test that runs afterwards.
    llava_cls = LlavaForConditionalGeneration
    if "_keep_in_fp32_modules" in vars(llava_cls):
        # the class defines its own value: put that exact value back
        self.addCleanup(setattr, llava_cls, "_keep_in_fp32_modules", vars(llava_cls)["_keep_in_fp32_modules"])
    else:
        # the value (if any) is inherited: removing our override re-exposes it
        self.addCleanup(delattr, llava_cls, "_keep_in_fp32_modules")
    llava_cls._keep_in_fp32_modules = ["multi_modal_projector"]

    model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, config=config, torch_dtype="auto")
    self.assertEqual(model.language_model.dtype, torch.float32)
    self.assertEqual(model.vision_tower.dtype, torch.bfloat16)
    self.assertEqual(model.multi_modal_projector.linear_1.weight.dtype, torch.float32)

    # torch.set_default_dtype() supports only float dtypes, so will fail with non-float type.
    # Each call gets its own `assertRaises` context: inside a shared context the first raise
    # leaves the block immediately and the second call would never be exercised.
    with self.assertRaises(ValueError):
        LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, torch_dtype="int64")

    with self.assertRaises(ValueError):
        LlavaForConditionalGeneration.from_pretrained(
            TINY_LLAVA, torch_dtype={"text_config": "float32", "vision_config": "int64", "": "float16"}
        )
|
||||
|
||||
@require_torch
|
||||
def test_model_from_pretrained_meta_device(self):
|
||||
def is_on_meta(model_id, dtype):
|
||||
|
||||
Reference in New Issue
Block a user