[from_pretrained] extend torch_dtype="auto" to look up config.torch_dtype first, expand docs (#21524)
* [from_pretrained] expand on torch_dtype entry * fold 4 into 1 * style * support torch_dtype='config' plus tests * style * oops * fold config into auto, fix bug * fix check * better log * better log * clean up
This commit is contained in:
@@ -2785,7 +2785,6 @@ class ModelUtilsTest(TestCasePlus):
|
||||
for p1, p2 in zip(model.parameters(), new_model.parameters()):
|
||||
self.assertTrue(torch.equal(p1, p2))
|
||||
|
||||
@require_torch
|
||||
def test_model_from_config_torch_dtype(self):
|
||||
# test that the model can be instantiated with dtype of user's choice - as long as it's a
|
||||
# float dtype. To make it happen config.torch_dtype needs to be set before instantiating the
|
||||
@@ -2804,7 +2803,6 @@ class ModelUtilsTest(TestCasePlus):
|
||||
with self.assertRaises(ValueError):
|
||||
model = AutoModel.from_config(config, torch_dtype=torch.int64)
|
||||
|
||||
@require_torch
|
||||
def test_model_from_pretrained_torch_dtype(self):
|
||||
# test that the model can be instantiated with dtype of either
|
||||
# 1. explicit from_pretrained's torch_dtype argument
|
||||
@@ -2818,11 +2816,25 @@ class ModelUtilsTest(TestCasePlus):
|
||||
model = T5ForConditionalGeneration.from_pretrained(TINY_T5)
|
||||
self.assertEqual(model.dtype, torch.float32)
|
||||
|
||||
def remove_torch_dtype(model_path):
|
||||
file = f"{model_path}/config.json"
|
||||
with open(file, "r", encoding="utf-8") as f:
|
||||
s = json.load(f)
|
||||
s.pop("torch_dtype")
|
||||
with open(file, "w", encoding="utf-8") as f:
|
||||
json.dump(s, f)
|
||||
|
||||
# test the default fp32 save_pretrained => from_pretrained cycle
|
||||
model.save_pretrained(model_path)
|
||||
model = T5ForConditionalGeneration.from_pretrained(model_path)
|
||||
self.assertEqual(model.dtype, torch.float32)
|
||||
# test with auto-detection
|
||||
# 1. test torch_dtype="auto" via `config.torch_dtype`
|
||||
model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype="auto")
|
||||
self.assertEqual(model.dtype, torch.float32)
|
||||
# 2. test torch_dtype="auto" via auto-derivation
|
||||
# now remove the torch_dtype entry from config.json and try "auto" again which should
|
||||
# perform auto-derivation from weights
|
||||
remove_torch_dtype(model_path)
|
||||
model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype="auto")
|
||||
self.assertEqual(model.dtype, torch.float32)
|
||||
|
||||
@@ -2833,24 +2845,32 @@ class ModelUtilsTest(TestCasePlus):
|
||||
# test fp16 save_pretrained, loaded with auto-detection
|
||||
model = model.half()
|
||||
model.save_pretrained(model_path)
|
||||
# 1. test torch_dtype="auto" via `config.torch_dtype`
|
||||
model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype="auto")
|
||||
self.assertEqual(model.config.torch_dtype, torch.float16)
|
||||
self.assertEqual(model.dtype, torch.float16)
|
||||
|
||||
# tests `config.torch_dtype` saving
|
||||
with open(f"{model_path}/config.json") as f:
|
||||
config_dict = json.load(f)
|
||||
self.assertEqual(config_dict["torch_dtype"], "float16")
|
||||
# 2. test torch_dtype="auto" via auto-derivation
|
||||
# now same with using config info
|
||||
remove_torch_dtype(model_path)
|
||||
model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype="auto")
|
||||
self.assertEqual(model.dtype, torch.float16)
|
||||
|
||||
# test fp16 save_pretrained, loaded with the explicit fp16
|
||||
model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype=torch.float16)
|
||||
self.assertEqual(model.dtype, torch.float16)
|
||||
|
||||
# test AutoModel separately as it goes through a different path
|
||||
# test auto-detection
|
||||
# test auto-detection - as currently TINY_T5 doesn't have torch_dtype entry
|
||||
model = AutoModel.from_pretrained(TINY_T5, torch_dtype="auto")
|
||||
# test that the config object didn't get polluted with torch_dtype="auto"
|
||||
# there was a bug that after this call we ended up with config.torch_dtype=="auto"
|
||||
self.assertNotEqual(model.config.torch_dtype, "auto")
|
||||
# now test the outcome
|
||||
self.assertEqual(model.dtype, torch.float32)
|
||||
# test forcing an explicit dtype
|
||||
model = AutoModel.from_pretrained(TINY_T5, torch_dtype=torch.float16)
|
||||
self.assertEqual(model.dtype, torch.float16)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user