[AutoModel] fix torch_dtype=auto in from_pretrained (#23379)
* [automodel] fix torch_dtype=auto in from_pretrained * add test * fix logic * Update src/transformers/models/auto/auto_factory.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --------- Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -2920,6 +2920,10 @@ class ModelUtilsTest(TestCasePlus):
|
||||
model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype="auto")
|
||||
self.assertEqual(model.dtype, torch.float16)
|
||||
|
||||
# 3. now retest that AutoModel behaves the same wrt torch_dtype="auto" as T5ForConditionalGeneration
|
||||
model = AutoModel.from_pretrained(model_path, torch_dtype="auto")
|
||||
self.assertEqual(model.dtype, torch.float16)
|
||||
|
||||
# test fp16 save_pretrained, loaded with the explicit fp16
|
||||
model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype=torch.float16)
|
||||
self.assertEqual(model.dtype, torch.float16)
|
||||
|
||||
Reference in New Issue
Block a user