fix AutoModel.from_pretrained(..., torch_dtype=...) (#13209)
* fix AutoModel.from_pretrained(..., torch_dtype=...) * fix to_diff_dict * add better test * torch is not always available when a model has self.torch_dtype
This commit is contained in:
@@ -16,6 +16,7 @@
|
||||
import copy
|
||||
import gc
|
||||
import inspect
|
||||
import json
|
||||
import os.path
|
||||
import random
|
||||
import tempfile
|
||||
@@ -1663,9 +1664,11 @@ class ModelUtilsTest(TestCasePlus):
|
||||
@require_torch
|
||||
def test_model_from_pretrained_torch_dtype(self):
|
||||
# test that the model can be instantiated with dtype of either
|
||||
# 1. config.torch_dtype setting in the saved model (priority)
|
||||
# 2. via autodiscovery by looking at model weights
|
||||
# 1. explicit from_pretrained's torch_dtype argument
|
||||
# 2. via autodiscovery by looking at model weights (torch_dtype="auto")
|
||||
# so if a model.half() was saved, we want it to be instantiated as such.
|
||||
#
|
||||
# test an explicit model class, but also AutoModel separately as the latter goes through a different code path
|
||||
model_path = self.get_auto_remove_tmp_dir()
|
||||
|
||||
# baseline - we know TINY_T5 is fp32 model
|
||||
@@ -1688,13 +1691,26 @@ class ModelUtilsTest(TestCasePlus):
|
||||
model = model.half()
|
||||
model.save_pretrained(model_path)
|
||||
model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype="auto")
|
||||
self.assertEqual(model.config.torch_dtype, "float16") # tests `config.torch_dtype` saving
|
||||
self.assertEqual(model.config.torch_dtype, torch.float16)
|
||||
self.assertEqual(model.dtype, torch.float16)
|
||||
|
||||
# tests `config.torch_dtype` saving
|
||||
with open(f"{model_path}/config.json") as f:
|
||||
config_dict = json.load(f)
|
||||
self.assertEqual(config_dict["torch_dtype"], "float16")
|
||||
|
||||
# test fp16 save_pretrained, loaded with the explicit fp16
|
||||
model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype=torch.float16)
|
||||
self.assertEqual(model.dtype, torch.float16)
|
||||
|
||||
# test AutoModel separately as it goes through a different path
|
||||
# test auto-detection
|
||||
model = AutoModel.from_pretrained(TINY_T5, torch_dtype="auto")
|
||||
self.assertEqual(model.dtype, torch.float32)
|
||||
# test forcing an explicit dtype
|
||||
model = AutoModel.from_pretrained(TINY_T5, torch_dtype=torch.float16)
|
||||
self.assertEqual(model.dtype, torch.float16)
|
||||
|
||||
|
||||
@require_torch
|
||||
@is_staging_test
|
||||
|
||||
Reference in New Issue
Block a user