@@ -1223,10 +1223,12 @@ def _get_torch_dtype(
|
|||||||
)
|
)
|
||||||
elif hasattr(torch, torch_dtype):
|
elif hasattr(torch, torch_dtype):
|
||||||
torch_dtype = getattr(torch, torch_dtype)
|
torch_dtype = getattr(torch, torch_dtype)
|
||||||
|
config.torch_dtype = torch_dtype
|
||||||
for sub_config_key in config.sub_configs.keys():
|
for sub_config_key in config.sub_configs.keys():
|
||||||
sub_config = getattr(config, sub_config_key)
|
sub_config = getattr(config, sub_config_key)
|
||||||
sub_config.torch_dtype = torch_dtype
|
sub_config.torch_dtype = torch_dtype
|
||||||
elif isinstance(torch_dtype, torch.dtype):
|
elif isinstance(torch_dtype, torch.dtype):
|
||||||
|
config.torch_dtype = torch_dtype
|
||||||
for sub_config_key in config.sub_configs.keys():
|
for sub_config_key in config.sub_configs.keys():
|
||||||
sub_config = getattr(config, sub_config_key)
|
sub_config = getattr(config, sub_config_key)
|
||||||
sub_config.torch_dtype = torch_dtype
|
sub_config.torch_dtype = torch_dtype
|
||||||
|
|||||||
@@ -36,6 +36,7 @@ from transformers import (
|
|||||||
AutoModel,
|
AutoModel,
|
||||||
AutoModelForImageClassification,
|
AutoModelForImageClassification,
|
||||||
AutoModelForSequenceClassification,
|
AutoModelForSequenceClassification,
|
||||||
|
CLIPTextModelWithProjection,
|
||||||
DynamicCache,
|
DynamicCache,
|
||||||
LlavaForConditionalGeneration,
|
LlavaForConditionalGeneration,
|
||||||
MistralForCausalLM,
|
MistralForCausalLM,
|
||||||
@@ -617,6 +618,14 @@ class ModelUtilsTest(TestCasePlus):
|
|||||||
model = AutoModel.from_pretrained(TINY_BERT_FOR_TOKEN_CLASSIFICATION, torch_dtype="auto")
|
model = AutoModel.from_pretrained(TINY_BERT_FOR_TOKEN_CLASSIFICATION, torch_dtype="auto")
|
||||||
self.assertEqual(model.dtype, torch.float32)
|
self.assertEqual(model.dtype, torch.float32)
|
||||||
|
|
||||||
|
# test model that init the model with _from_config
|
||||||
|
model = CLIPTextModelWithProjection.from_pretrained(
|
||||||
|
"hf-internal-testing/diffusers-stable-diffusion-tiny-all",
|
||||||
|
subfolder="text_encoder",
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
)
|
||||||
|
self.assertEqual(model.dtype, torch.bfloat16)
|
||||||
|
|
||||||
def test_model_from_pretrained_attn_implementation(self):
|
def test_model_from_pretrained_attn_implementation(self):
|
||||||
# test that the model can be instantiated with attn_implementation of either
|
# test that the model can be instantiated with attn_implementation of either
|
||||||
# 1. explicit from_pretrained's attn_implementation argument
|
# 1. explicit from_pretrained's attn_implementation argument
|
||||||
|
|||||||
Reference in New Issue
Block a user