[QoL] Allow dtype str for torch_dtype arg of from_pretrained (#31590)
* Allow dtype str for torch_dtype in from_pretrained
* Update docstring
* Add tests for str torch_dtype
This commit is contained in:
@@ -445,6 +445,18 @@ class ModelUtilsTest(TestCasePlus):
|
||||
with self.assertRaises(ValueError):
|
||||
model = AutoModel.from_config(config, torch_dtype=torch.int64)
|
||||
|
||||
def test_model_from_config_torch_dtype_str(self):
    """Check that `from_pretrained` accepts `torch_dtype` given as a string.

    A float dtype name such as "float32" must produce a model whose `dtype` is the
    matching `torch.dtype`; a non-float name such as "int64" must raise `ValueError`.
    """
    # (string passed as torch_dtype, dtype the loaded model must report)
    expected_dtypes = (
        ("float32", torch.float32),
        ("float16", torch.float16),
    )
    for dtype_name, expected_dtype in expected_dtypes:
        model = AutoModel.from_pretrained(TINY_T5, torch_dtype=dtype_name)
        self.assertEqual(model.dtype, expected_dtype)

    # torch.set_default_dtype() supports only float dtypes, so a non-float type must fail
    with self.assertRaises(ValueError):
        AutoModel.from_pretrained(TINY_T5, torch_dtype="int64")
|
||||
|
||||
def test_model_from_pretrained_torch_dtype(self):
|
||||
# test that the model can be instantiated with dtype of either
|
||||
# 1. explicit from_pretrained's torch_dtype argument
|
||||
|
||||
Reference in New Issue
Block a user