[modeling_utils] torch_dtype/auto floating dtype fixes (#17614)
* [modeling_utils] torch_dtype/auto fixes * add test * apply suggestions * add missing fallback * Renaming things * Use for else Co-authored-by: Sylvain Gugger <Sylvain.gugger@gmail.com>
This commit is contained in:
@@ -134,6 +134,7 @@ def _config_zero_init(config):
|
||||
|
||||
|
||||
TINY_T5 = "patrickvonplaten/t5-tiny-random"
|
||||
TINY_BERT_FOR_TOKEN_CLASSIFICATION = "hf-internal-testing/tiny-bert-for-token-classification"
|
||||
|
||||
|
||||
@require_torch
|
||||
@@ -2557,6 +2558,10 @@ class ModelUtilsTest(TestCasePlus):
|
||||
model = AutoModel.from_pretrained(TINY_T5, torch_dtype=torch.float16)
|
||||
self.assertEqual(model.dtype, torch.float16)
|
||||
|
||||
# test model whose first param is not of a floating type, but int
|
||||
model = AutoModel.from_pretrained(TINY_BERT_FOR_TOKEN_CLASSIFICATION, torch_dtype="auto")
|
||||
self.assertEqual(model.dtype, torch.float32)
|
||||
|
||||
def test_no_super_init_config_and_model(self):
|
||||
config = NoSuperInitConfig(attribute=32)
|
||||
model = NoSuperInitModel(config)
|
||||
|
||||
Reference in New Issue
Block a user