[modeling_utils] torch_dtype/auto floating dtype fixes (#17614)

* [modeling_utils] torch_dtype/auto fixes

* add test

* apply suggestions

* add missing fallback

* Renaming things

* Use for else

Co-authored-by: Sylvain Gugger <Sylvain.gugger@gmail.com>
This commit is contained in:
Stas Bekman
2022-06-09 10:18:26 -07:00
committed by GitHub
parent c38f4e1f1c
commit 75343de938
2 changed files with 64 additions and 4 deletions

View File

@@ -134,6 +134,7 @@ def _config_zero_init(config):
TINY_T5 = "patrickvonplaten/t5-tiny-random"
TINY_BERT_FOR_TOKEN_CLASSIFICATION = "hf-internal-testing/tiny-bert-for-token-classification"
@require_torch
@@ -2557,6 +2558,10 @@ class ModelUtilsTest(TestCasePlus):
model = AutoModel.from_pretrained(TINY_T5, torch_dtype=torch.float16)
self.assertEqual(model.dtype, torch.float16)
# test model whose first param is not of a floating type, but int
model = AutoModel.from_pretrained(TINY_BERT_FOR_TOKEN_CLASSIFICATION, torch_dtype="auto")
self.assertEqual(model.dtype, torch.float32)
def test_no_super_init_config_and_model(self):
config = NoSuperInitConfig(attribute=32)
model = NoSuperInitModel(config)