From 3a028101e91069b51629f5e74096ae78e490022b Mon Sep 17 00:00:00 2001 From: Billy Cao Date: Thu, 27 Jun 2024 18:41:49 +0800 Subject: [PATCH] [QoL] Allow dtype str for torch_dtype arg of from_pretrained (#31590) * Allow dtype str for torch_dtype in from_pretrained * Update docstring * Add tests for str torch_dtype --- src/transformers/modeling_utils.py | 6 +++++- tests/test_modeling_utils.py | 12 ++++++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index f7b0db6d77..c991c1c95b 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -2958,6 +2958,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix using the `dtype` it was saved in at the end of the training. It can't be used as an indicator of how the model was trained. Since it could be trained in one of half precision dtypes, but saved in fp32. + 3. A string that is a valid `torch.dtype`. E.g. "float32" loads the model in `torch.float32`, "float16" loads in `torch.float16` etc. + For some models the `dtype` they were trained in is unknown - you may try to check the model's paper or @@ -3661,9 +3663,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix "Since the `torch_dtype` attribute can't be found in model's config object, " "will use torch_dtype={torch_dtype} as derived from model's weights" ) + elif hasattr(torch, torch_dtype): + torch_dtype = getattr(torch, torch_dtype) else: raise ValueError( - f'`torch_dtype` can be either `torch.dtype` or `"auto"`, but received {torch_dtype}' + f'`torch_dtype` can be one of: `torch.dtype`, `"auto"` or a string of a valid `torch.dtype`, but received {torch_dtype}' ) dtype_orig = cls._set_default_torch_dtype(torch_dtype) diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index f5b30d5033..758fe4d1fd 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -445,6 +445,18 @@ class ModelUtilsTest(TestCasePlus): with self.assertRaises(ValueError): model = AutoModel.from_config(config, torch_dtype=torch.int64) + def test_model_from_config_torch_dtype_str(self): + # test that from_pretrained works with torch_dtype being strings like "float32" for PyTorch backend + model = AutoModel.from_pretrained(TINY_T5, torch_dtype="float32") + self.assertEqual(model.dtype, torch.float32) + + model = AutoModel.from_pretrained(TINY_T5, torch_dtype="float16") + self.assertEqual(model.dtype, torch.float16) + + # torch.set_default_dtype() supports only float dtypes, so will fail with non-float type + with self.assertRaises(ValueError): + model = AutoModel.from_pretrained(TINY_T5, torch_dtype="int64") + def test_model_from_pretrained_torch_dtype(self): # test that the model can be instantiated with dtype of either # 1. explicit from_pretrained's torch_dtype argument