From fbb18ce68ba552e88a0e1c6518b3f641380f3335 Mon Sep 17 00:00:00 2001 From: Marc Sun <57196510+SunMarc@users.noreply.github.com> Date: Thu, 13 Mar 2025 12:08:02 +0100 Subject: [PATCH] Update config.torch_dtype correctly (#36679) * fix * style * new test --- src/transformers/modeling_utils.py | 2 ++ tests/utils/test_modeling_utils.py | 9 +++++++++ 2 files changed, 11 insertions(+) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index a33f556f84..207ddafa97 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1223,10 +1223,12 @@ def _get_torch_dtype( ) elif hasattr(torch, torch_dtype): torch_dtype = getattr(torch, torch_dtype) + config.torch_dtype = torch_dtype for sub_config_key in config.sub_configs.keys(): sub_config = getattr(config, sub_config_key) sub_config.torch_dtype = torch_dtype elif isinstance(torch_dtype, torch.dtype): + config.torch_dtype = torch_dtype for sub_config_key in config.sub_configs.keys(): sub_config = getattr(config, sub_config_key) sub_config.torch_dtype = torch_dtype diff --git a/tests/utils/test_modeling_utils.py b/tests/utils/test_modeling_utils.py index a85b598c08..7d69073147 100644 --- a/tests/utils/test_modeling_utils.py +++ b/tests/utils/test_modeling_utils.py @@ -36,6 +36,7 @@ from transformers import ( AutoModel, AutoModelForImageClassification, AutoModelForSequenceClassification, + CLIPTextModelWithProjection, DynamicCache, LlavaForConditionalGeneration, MistralForCausalLM, @@ -617,6 +618,14 @@ class ModelUtilsTest(TestCasePlus): model = AutoModel.from_pretrained(TINY_BERT_FOR_TOKEN_CLASSIFICATION, torch_dtype="auto") self.assertEqual(model.dtype, torch.float32) + # test model that init the model with _from_config + model = CLIPTextModelWithProjection.from_pretrained( + "hf-internal-testing/diffusers-stable-diffusion-tiny-all", + subfolder="text_encoder", + torch_dtype=torch.bfloat16, + ) + self.assertEqual(model.dtype, torch.bfloat16) + def test_model_from_pretrained_attn_implementation(self): # test that the model can be instantiated with attn_implementation of either # 1. explicit from_pretrained's attn_implementation argument