Update config.torch_dtype correctly (#36679)

* fix

* style

* new test
This commit is contained in:
Marc Sun
2025-03-13 12:08:02 +01:00
committed by GitHub
parent c4161238bd
commit fbb18ce68b
2 changed files with 11 additions and 0 deletions

View File

@@ -36,6 +36,7 @@ from transformers import (
AutoModel,
AutoModelForImageClassification,
AutoModelForSequenceClassification,
CLIPTextModelWithProjection,
DynamicCache,
LlavaForConditionalGeneration,
MistralForCausalLM,
@@ -617,6 +618,14 @@ class ModelUtilsTest(TestCasePlus):
model = AutoModel.from_pretrained(TINY_BERT_FOR_TOKEN_CLASSIFICATION, torch_dtype="auto")
self.assertEqual(model.dtype, torch.float32)
# test model that init the model with _from_config
model = CLIPTextModelWithProjection.from_pretrained(
"hf-internal-testing/diffusers-stable-diffusion-tiny-all",
subfolder="text_encoder",
torch_dtype=torch.bfloat16,
)
self.assertEqual(model.dtype, torch.bfloat16)
def test_model_from_pretrained_attn_implementation(self):
# test that the model can be instantiated with attn_implementation of either
# 1. explicit from_pretrained's attn_implementation argument