Enable different torch dtype in sub models (#34873)

* fix

* fix test

* add tests

* add more tests

* fix tests

* supposed to be a torch.dtype test

* handle BC and make fp32 default
This commit is contained in:
Raushan Turganbay
2025-01-13 13:42:08 +01:00
committed by GitHub
parent 87089176d9
commit 84a6789145
5 changed files with 155 additions and 68 deletions

View File

@@ -227,6 +227,7 @@ class Qwen2VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
pipeline_model_mapping = {"image-text-to-text": Qwen2VLForConditionalGeneration}
test_pruning = False
test_head_masking = False
_is_composite = True
def setUp(self):
self.model_tester = Qwen2VLVisionText2TextModelTester(self)

View File

@@ -37,6 +37,7 @@ from transformers import (
AutoModel,
AutoModelForImageClassification,
AutoModelForSequenceClassification,
LlavaForConditionalGeneration,
OwlViTForObjectDetection,
PretrainedConfig,
is_torch_available,
@@ -300,6 +301,7 @@ TINY_T5 = "patrickvonplaten/t5-tiny-random"
TINY_BERT_FOR_TOKEN_CLASSIFICATION = "hf-internal-testing/tiny-bert-for-token-classification"
TINY_MISTRAL = "hf-internal-testing/tiny-random-MistralForCausalLM"
TINY_IMAGE_CLASSIF = "hf-internal-testing/tiny-random-SiglipForImageClassification"
TINY_LLAVA = "hf-internal-testing/tiny-random-LlavaForConditionalGeneration"
LOG = logging.get_logger(__name__)
@@ -460,6 +462,59 @@ class ModelUtilsTest(TestCasePlus):
with self.assertRaises(ValueError):
model = AutoModel.from_pretrained(TINY_T5, torch_dtype="int64")
def test_model_from_config_torch_dtype_composite(self):
"""
Test that from_pretrained works with torch_dtype being as a dict per each sub-config in composite config
"""
# should be able to set torch_dtype as a simple string and the model loads it correctly
model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, torch_dtype="float32")
self.assertEqual(model.language_model.dtype, torch.float32)
self.assertEqual(model.vision_tower.dtype, torch.float32)
model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, torch_dtype="float16")
self.assertEqual(model.language_model.dtype, torch.float16)
self.assertEqual(model.vision_tower.dtype, torch.float16)
# should be able to set torch_dtype as a dict for each sub-config
model = LlavaForConditionalGeneration.from_pretrained(
TINY_LLAVA, torch_dtype={"text_config": "float32", "vision_config": "float16", "": "bfloat16"}
)
self.assertEqual(model.language_model.dtype, torch.float32)
self.assertEqual(model.vision_tower.dtype, torch.float16)
self.assertEqual(model.multi_modal_projector.linear_1.weight.dtype, torch.bfloat16)
# should be able to set the values as torch.dtype (not str)
model = LlavaForConditionalGeneration.from_pretrained(
TINY_LLAVA, torch_dtype={"text_config": torch.float32, "vision_config": torch.float16, "": torch.bfloat16}
)
self.assertEqual(model.language_model.dtype, torch.float32)
self.assertEqual(model.vision_tower.dtype, torch.float16)
self.assertEqual(model.multi_modal_projector.linear_1.weight.dtype, torch.bfloat16)
# should be able to set the values in configs directly and pass it to `from_pretrained`
config = copy.deepcopy(model.config)
config.text_config.torch_dtype = torch.float32
config.vision_config.torch_dtype = torch.bfloat16
config.torch_dtype = torch.float16
model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, config=config, torch_dtype="auto")
self.assertEqual(model.language_model.dtype, torch.float32)
self.assertEqual(model.vision_tower.dtype, torch.bfloat16)
self.assertEqual(model.multi_modal_projector.linear_1.weight.dtype, torch.float16)
# but if the model has `_keep_in_fp32_modules` then those modules should be in fp32 no matter what
LlavaForConditionalGeneration._keep_in_fp32_modules = ["multi_modal_projector"]
model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, config=config, torch_dtype="auto")
self.assertEqual(model.language_model.dtype, torch.float32)
self.assertEqual(model.vision_tower.dtype, torch.bfloat16)
self.assertEqual(model.multi_modal_projector.linear_1.weight.dtype, torch.float32)
# torch.set_default_dtype() supports only float dtypes, so will fail with non-float type
with self.assertRaises(ValueError):
model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, torch_dtype="int64")
model = LlavaForConditionalGeneration.from_pretrained(
TINY_LLAVA, torch_dtype={"text_config": "float32", "vision_config": "int64", "": "float16"}
)
@require_torch
def test_model_from_pretrained_meta_device(self):
def is_on_meta(model_id, dtype):