@@ -1223,10 +1223,12 @@ def _get_torch_dtype(
|
||||
)
|
||||
elif hasattr(torch, torch_dtype):
|
||||
torch_dtype = getattr(torch, torch_dtype)
|
||||
config.torch_dtype = torch_dtype
|
||||
for sub_config_key in config.sub_configs.keys():
|
||||
sub_config = getattr(config, sub_config_key)
|
||||
sub_config.torch_dtype = torch_dtype
|
||||
elif isinstance(torch_dtype, torch.dtype):
|
||||
config.torch_dtype = torch_dtype
|
||||
for sub_config_key in config.sub_configs.keys():
|
||||
sub_config = getattr(config, sub_config_key)
|
||||
sub_config.torch_dtype = torch_dtype
|
||||
|
||||
@@ -36,6 +36,7 @@ from transformers import (
|
||||
AutoModel,
|
||||
AutoModelForImageClassification,
|
||||
AutoModelForSequenceClassification,
|
||||
CLIPTextModelWithProjection,
|
||||
DynamicCache,
|
||||
LlavaForConditionalGeneration,
|
||||
MistralForCausalLM,
|
||||
@@ -617,6 +618,14 @@ class ModelUtilsTest(TestCasePlus):
|
||||
model = AutoModel.from_pretrained(TINY_BERT_FOR_TOKEN_CLASSIFICATION, torch_dtype="auto")
|
||||
self.assertEqual(model.dtype, torch.float32)
|
||||
|
||||
# test model that init the model with _from_config
|
||||
model = CLIPTextModelWithProjection.from_pretrained(
|
||||
"hf-internal-testing/diffusers-stable-diffusion-tiny-all",
|
||||
subfolder="text_encoder",
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
self.assertEqual(model.dtype, torch.bfloat16)
|
||||
|
||||
def test_model_from_pretrained_attn_implementation(self):
|
||||
# test that the model can be instantiated with attn_implementation of either
|
||||
# 1. explicit from_pretrained's attn_implementation argument
|
||||
|
||||
Reference in New Issue
Block a user