[AutoModel] fix torch_dtype=auto in from_pretrained (#23379)

* [automodel] fix torch_dtype=auto in from_pretrained

* add test

* fix logic

* Update src/transformers/models/auto/auto_factory.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

---------

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Stas Bekman
2023-05-16 10:21:42 -07:00
committed by GitHub
parent 8a58809312
commit bbbc5c15d4
2 changed files with 13 additions and 4 deletions

View File

@@ -2920,6 +2920,10 @@ class ModelUtilsTest(TestCasePlus):
model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype="auto")
self.assertEqual(model.dtype, torch.float16)
# 3. now retest that AutoModel behaves the same wrt torch_dtype="auto" as T5ForConditionalGeneration
model = AutoModel.from_pretrained(model_path, torch_dtype="auto")
self.assertEqual(model.dtype, torch.float16)
# test fp16 save_pretrained, loaded with the explicit fp16
model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype=torch.float16)
self.assertEqual(model.dtype, torch.float16)