[QoL] Allow dtype str for torch_dtype arg of from_pretrained (#31590)

* Allow dtype str for torch_dtype in from_pretrained

* Update docstring

* Add tests for str torch_dtype
This commit is contained in:
Billy Cao
2024-06-27 18:41:49 +08:00
committed by GitHub
parent 11138ca013
commit 3a028101e9
2 changed files with 17 additions and 1 deletions

View File

@@ -445,6 +445,18 @@ class ModelUtilsTest(TestCasePlus):
with self.assertRaises(ValueError):
model = AutoModel.from_config(config, torch_dtype=torch.int64)
def test_model_from_config_torch_dtype_str(self):
    """Check that `from_pretrained` accepts `torch_dtype` given as a string such as "float32" (PyTorch backend)."""
    # each dtype name must yield a model whose dtype is the matching torch dtype
    for dtype_name, expected_dtype in (("float32", torch.float32), ("float16", torch.float16)):
        loaded = AutoModel.from_pretrained(TINY_T5, torch_dtype=dtype_name)
        self.assertEqual(loaded.dtype, expected_dtype)

    # torch.set_default_dtype() supports only float dtypes, so a non-float name is expected to raise
    with self.assertRaises(ValueError):
        AutoModel.from_pretrained(TINY_T5, torch_dtype="int64")
def test_model_from_pretrained_torch_dtype(self):
# test that the model can be instantiated with dtype of either
# 1. explicit from_pretrained's torch_dtype argument