fix AutoModel.from_pretrained(..., torch_dtype=...) (#13209)

* fix AutoModel.from_pretrained(..., torch_dtype=...)

* fix to_diff_dict

* add better test

* torch is not always available when a model has self.torch_dtype
This commit is contained in:
Stas Bekman
2021-08-24 02:43:41 -07:00
committed by GitHub
parent 39db2f3c19
commit 5c6eca71a9
2 changed files with 46 additions and 4 deletions

View File

@@ -16,6 +16,7 @@
import copy
import gc
import inspect
import json
import os.path
import random
import tempfile
@@ -1663,9 +1664,11 @@ class ModelUtilsTest(TestCasePlus):
@require_torch
def test_model_from_pretrained_torch_dtype(self):
# test that the model can be instantiated with dtype of either
# 1. config.torch_dtype setting in the saved model (priority)
# 2. via autodiscovery by looking at model weights
# 1. explicit from_pretrained's torch_dtype argument
# 2. via autodiscovery by looking at model weights (torch_dtype="auto")
# so if a model.half() was saved, we want it to be instantiated as such.
#
# test an explicit model class, but also AutoModel separately as the latter goes through a different code path
model_path = self.get_auto_remove_tmp_dir()
# baseline - we know TINY_T5 is fp32 model
@@ -1688,13 +1691,26 @@ class ModelUtilsTest(TestCasePlus):
model = model.half()
model.save_pretrained(model_path)
model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype="auto")
self.assertEqual(model.config.torch_dtype, "float16") # tests `config.torch_dtype` saving
self.assertEqual(model.config.torch_dtype, torch.float16)
self.assertEqual(model.dtype, torch.float16)
# tests `config.torch_dtype` saving
with open(f"{model_path}/config.json") as f:
config_dict = json.load(f)
self.assertEqual(config_dict["torch_dtype"], "float16")
# test fp16 save_pretrained, loaded with the explicit fp16
model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype=torch.float16)
self.assertEqual(model.dtype, torch.float16)
# test AutoModel separately as it goes through a different path
# test auto-detection
model = AutoModel.from_pretrained(TINY_T5, torch_dtype="auto")
self.assertEqual(model.dtype, torch.float32)
# test forcing an explicit dtype
model = AutoModel.from_pretrained(TINY_T5, torch_dtype=torch.float16)
self.assertEqual(model.dtype, torch.float16)
@require_torch
@is_staging_test