[from_pretrained] extend torch_dtype="auto" to look up config.torch_dtype first, expand docs (#21524)
* [from_pretrained] expand on torch_dtype entry
* fold 4 into 1
* style
* support torch_dtype='config' plus tests
* style
* oops
* fold config into auto, fix bug
* fix check
* better log
* better log
* clean up
This commit is contained in:
@@ -1904,7 +1904,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
|
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
|
||||||
identifier allowed by git.
|
identifier allowed by git.
|
||||||
|
|
||||||
|
|
||||||
<Tip>
|
<Tip>
|
||||||
|
|
||||||
To test a pull request you made on the Hub, you can pass `revision="refs/pr/<pr_number>"`.
|
To test a pull request you made on the Hub, you can pass `revision="refs/pr/<pr_number>"`.
|
||||||
@@ -1932,8 +1931,27 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
Tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
|
Tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
|
||||||
This is an experimental feature and subject to change at any moment.
|
This is an experimental feature and subject to change at any moment.
|
||||||
torch_dtype (`str` or `torch.dtype`, *optional*):
|
torch_dtype (`str` or `torch.dtype`, *optional*):
|
||||||
Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
|
Override the default `torch.dtype` and load the model under a specific `dtype`. The different options
|
||||||
will be automatically derived from the model's weights.
|
are:
|
||||||
|
|
||||||
|
1. `torch.float16` or `torch.bfloat16` or `torch.float`: load in a specified
|
||||||
|
`dtype`, ignoring the model's `config.torch_dtype` if one exists. If not specified,
|
||||||
|
the model will get loaded in `torch.float` (fp32).
|
||||||
|
|
||||||
|
2. `"auto"` - The `torch_dtype` entry in the model's `config.json` file is used if it
|
||||||
|
exists. If this entry isn't found, the `dtype` of the first floating point weight in
|
||||||
|
the checkpoint is used instead. This will load the model
|
||||||
|
using the `dtype` it was saved in at the end of the training. Note that this is not an indicator of how
|
||||||
|
the model was trained, since it could have been trained in one of the half precision dtypes but saved in fp32.
|
||||||
|
|
||||||
|
<Tip>
|
||||||
|
|
||||||
|
For some models the `dtype` they were trained in is unknown - you may try to check the model's paper or
|
||||||
|
reach out to the authors and ask them to add this information to the model's card and to insert the
|
||||||
|
`torch_dtype` entry in `config.json` on the hub.
|
||||||
|
|
||||||
|
</Tip>
|
||||||
|
|
||||||
device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
|
device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
|
||||||
A map that specifies where each submodule should go. It doesn't need to be refined to each
|
A map that specifies where each submodule should go. It doesn't need to be refined to each
|
||||||
parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the
|
parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the
|
||||||
@@ -2098,10 +2116,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
" bitsandbytes `pip install -i https://test.pypi.org/simple/ bitsandbytes` or"
|
" bitsandbytes `pip install -i https://test.pypi.org/simple/ bitsandbytes` or"
|
||||||
" pip install bitsandbytes` "
|
" pip install bitsandbytes` "
|
||||||
)
|
)
|
||||||
if torch_dtype == "auto" or torch_dtype != torch.float16:
|
if torch_dtype != torch.float16:
|
||||||
# We force the `dtype` to be float16, this is a requirement from `bitsandbytes`
|
# We force the `dtype` to be float16, this is a requirement from `bitsandbytes`
|
||||||
|
logger.warning(
|
||||||
|
f"Overriding torch_dtype={torch_dtype} with `torch_dtype=torch.float16` due to "
|
||||||
|
"requirements of `bitsandbytes` to enable model loading in mixed int8. "
|
||||||
|
"Either pass torch_dtype=torch.float16 or don't pass this argument at all to remove this warning."
|
||||||
|
)
|
||||||
torch_dtype = torch.float16
|
torch_dtype = torch.float16
|
||||||
logger.info("Loading the model in mixed int8 - forcing the weights to be casted in float16")
|
|
||||||
if device_map is None:
|
if device_map is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"A device map needs to be passed to run convert models into mixed-int8 format. Please run"
|
"A device map needs to be passed to run convert models into mixed-int8 format. Please run"
|
||||||
@@ -2388,6 +2411,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
if torch_dtype is not None:
|
if torch_dtype is not None:
|
||||||
if isinstance(torch_dtype, str):
|
if isinstance(torch_dtype, str):
|
||||||
if torch_dtype == "auto":
|
if torch_dtype == "auto":
|
||||||
|
if hasattr(config, "torch_dtype") and config.torch_dtype is not None:
|
||||||
|
torch_dtype = config.torch_dtype
|
||||||
|
logger.info(f"Will use torch_dtype={torch_dtype} as defined in model's config object")
|
||||||
|
else:
|
||||||
if is_sharded and "dtype" in sharded_metadata:
|
if is_sharded and "dtype" in sharded_metadata:
|
||||||
torch_dtype = sharded_metadata["dtype"]
|
torch_dtype = sharded_metadata["dtype"]
|
||||||
elif not is_sharded:
|
elif not is_sharded:
|
||||||
@@ -2396,9 +2423,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
one_state_dict = load_state_dict(resolved_archive_file[0])
|
one_state_dict = load_state_dict(resolved_archive_file[0])
|
||||||
torch_dtype = get_state_dict_dtype(one_state_dict)
|
torch_dtype = get_state_dict_dtype(one_state_dict)
|
||||||
del one_state_dict # free CPU memory
|
del one_state_dict # free CPU memory
|
||||||
|
logger.info(
|
||||||
|
"Since the `torch_dtype` attribute can't be found in model's config object, "
|
||||||
|
f"will use torch_dtype={torch_dtype} as derived from model's weights"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"`torch_dtype` can be either a `torch.dtype` or `auto`, but received {torch_dtype}"
|
f'`torch_dtype` can be either `torch.dtype` or `"auto"`, but received {torch_dtype}'
|
||||||
)
|
)
|
||||||
dtype_orig = cls._set_default_torch_dtype(torch_dtype)
|
dtype_orig = cls._set_default_torch_dtype(torch_dtype)
|
||||||
|
|
||||||
|
|||||||
@@ -13,6 +13,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Factory function to build auto-model classes."""
|
"""Factory function to build auto-model classes."""
|
||||||
|
import copy
|
||||||
import importlib
|
import importlib
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
|
||||||
@@ -431,12 +432,18 @@ class _BaseAutoModelClass:
|
|||||||
]
|
]
|
||||||
hub_kwargs = {name: kwargs.pop(name) for name in hub_kwargs_names if name in kwargs}
|
hub_kwargs = {name: kwargs.pop(name) for name in hub_kwargs_names if name in kwargs}
|
||||||
if not isinstance(config, PretrainedConfig):
|
if not isinstance(config, PretrainedConfig):
|
||||||
|
kwargs_copy = copy.deepcopy(kwargs)
|
||||||
|
# ensure not to pollute the config object with torch_dtype="auto" - since it's
|
||||||
|
# meaningless in the context of the config object - torch.dtype values are acceptable
|
||||||
|
if kwargs_copy.get("torch_dtype", None) == "auto":
|
||||||
|
_ = kwargs_copy.pop("torch_dtype")
|
||||||
|
|
||||||
config, kwargs = AutoConfig.from_pretrained(
|
config, kwargs = AutoConfig.from_pretrained(
|
||||||
pretrained_model_name_or_path,
|
pretrained_model_name_or_path,
|
||||||
return_unused_kwargs=True,
|
return_unused_kwargs=True,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
**hub_kwargs,
|
**hub_kwargs,
|
||||||
**kwargs,
|
**kwargs_copy,
|
||||||
)
|
)
|
||||||
if hasattr(config, "auto_map") and cls.__name__ in config.auto_map:
|
if hasattr(config, "auto_map") and cls.__name__ in config.auto_map:
|
||||||
if not trust_remote_code:
|
if not trust_remote_code:
|
||||||
|
|||||||
@@ -2785,7 +2785,6 @@ class ModelUtilsTest(TestCasePlus):
|
|||||||
for p1, p2 in zip(model.parameters(), new_model.parameters()):
|
for p1, p2 in zip(model.parameters(), new_model.parameters()):
|
||||||
self.assertTrue(torch.equal(p1, p2))
|
self.assertTrue(torch.equal(p1, p2))
|
||||||
|
|
||||||
@require_torch
|
|
||||||
def test_model_from_config_torch_dtype(self):
|
def test_model_from_config_torch_dtype(self):
|
||||||
# test that the model can be instantiated with dtype of user's choice - as long as it's a
|
# test that the model can be instantiated with dtype of user's choice - as long as it's a
|
||||||
# float dtype. To make it happen config.torch_dtype needs to be set before instantiating the
|
# float dtype. To make it happen config.torch_dtype needs to be set before instantiating the
|
||||||
@@ -2804,7 +2803,6 @@ class ModelUtilsTest(TestCasePlus):
|
|||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
model = AutoModel.from_config(config, torch_dtype=torch.int64)
|
model = AutoModel.from_config(config, torch_dtype=torch.int64)
|
||||||
|
|
||||||
@require_torch
|
|
||||||
def test_model_from_pretrained_torch_dtype(self):
|
def test_model_from_pretrained_torch_dtype(self):
|
||||||
# test that the model can be instantiated with dtype of either
|
# test that the model can be instantiated with dtype of either
|
||||||
# 1. explicit from_pretrained's torch_dtype argument
|
# 1. explicit from_pretrained's torch_dtype argument
|
||||||
@@ -2818,11 +2816,25 @@ class ModelUtilsTest(TestCasePlus):
|
|||||||
model = T5ForConditionalGeneration.from_pretrained(TINY_T5)
|
model = T5ForConditionalGeneration.from_pretrained(TINY_T5)
|
||||||
self.assertEqual(model.dtype, torch.float32)
|
self.assertEqual(model.dtype, torch.float32)
|
||||||
|
|
||||||
|
def remove_torch_dtype(model_path):
|
||||||
|
file = f"{model_path}/config.json"
|
||||||
|
with open(file, "r", encoding="utf-8") as f:
|
||||||
|
s = json.load(f)
|
||||||
|
s.pop("torch_dtype")
|
||||||
|
with open(file, "w", encoding="utf-8") as f:
|
||||||
|
json.dump(s, f)
|
||||||
|
|
||||||
# test the default fp32 save_pretrained => from_pretrained cycle
|
# test the default fp32 save_pretrained => from_pretrained cycle
|
||||||
model.save_pretrained(model_path)
|
model.save_pretrained(model_path)
|
||||||
model = T5ForConditionalGeneration.from_pretrained(model_path)
|
model = T5ForConditionalGeneration.from_pretrained(model_path)
|
||||||
self.assertEqual(model.dtype, torch.float32)
|
self.assertEqual(model.dtype, torch.float32)
|
||||||
# test with auto-detection
|
# 1. test torch_dtype="auto" via `config.torch_dtype`
|
||||||
|
model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype="auto")
|
||||||
|
self.assertEqual(model.dtype, torch.float32)
|
||||||
|
# 2. test torch_dtype="auto" via auto-derivation
|
||||||
|
# now remove the torch_dtype entry from config.json and try "auto" again which should
|
||||||
|
# perform auto-derivation from weights
|
||||||
|
remove_torch_dtype(model_path)
|
||||||
model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype="auto")
|
model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype="auto")
|
||||||
self.assertEqual(model.dtype, torch.float32)
|
self.assertEqual(model.dtype, torch.float32)
|
||||||
|
|
||||||
@@ -2833,24 +2845,32 @@ class ModelUtilsTest(TestCasePlus):
|
|||||||
# test fp16 save_pretrained, loaded with auto-detection
|
# test fp16 save_pretrained, loaded with auto-detection
|
||||||
model = model.half()
|
model = model.half()
|
||||||
model.save_pretrained(model_path)
|
model.save_pretrained(model_path)
|
||||||
|
# 1. test torch_dtype="auto" via `config.torch_dtype`
|
||||||
model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype="auto")
|
model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype="auto")
|
||||||
self.assertEqual(model.config.torch_dtype, torch.float16)
|
self.assertEqual(model.config.torch_dtype, torch.float16)
|
||||||
self.assertEqual(model.dtype, torch.float16)
|
self.assertEqual(model.dtype, torch.float16)
|
||||||
|
|
||||||
# tests `config.torch_dtype` saving
|
# tests `config.torch_dtype` saving
|
||||||
with open(f"{model_path}/config.json") as f:
|
with open(f"{model_path}/config.json") as f:
|
||||||
config_dict = json.load(f)
|
config_dict = json.load(f)
|
||||||
self.assertEqual(config_dict["torch_dtype"], "float16")
|
self.assertEqual(config_dict["torch_dtype"], "float16")
|
||||||
|
# 2. test torch_dtype="auto" via auto-derivation
|
||||||
|
# now remove the torch_dtype entry from config.json so that the dtype is derived from the weights
|
||||||
|
remove_torch_dtype(model_path)
|
||||||
|
model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype="auto")
|
||||||
|
self.assertEqual(model.dtype, torch.float16)
|
||||||
|
|
||||||
# test fp16 save_pretrained, loaded with the explicit fp16
|
# test fp16 save_pretrained, loaded with the explicit fp16
|
||||||
model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype=torch.float16)
|
model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype=torch.float16)
|
||||||
self.assertEqual(model.dtype, torch.float16)
|
self.assertEqual(model.dtype, torch.float16)
|
||||||
|
|
||||||
# test AutoModel separately as it goes through a different path
|
# test AutoModel separately as it goes through a different path
|
||||||
# test auto-detection
|
# test auto-detection - as currently TINY_T5 doesn't have torch_dtype entry
|
||||||
model = AutoModel.from_pretrained(TINY_T5, torch_dtype="auto")
|
model = AutoModel.from_pretrained(TINY_T5, torch_dtype="auto")
|
||||||
|
# test that the config object didn't get polluted with torch_dtype="auto"
|
||||||
|
# there was a bug that after this call we ended up with config.torch_dtype=="auto"
|
||||||
|
self.assertNotEqual(model.config.torch_dtype, "auto")
|
||||||
|
# now test the outcome
|
||||||
self.assertEqual(model.dtype, torch.float32)
|
self.assertEqual(model.dtype, torch.float32)
|
||||||
# test forcing an explicit dtype
|
|
||||||
model = AutoModel.from_pretrained(TINY_T5, torch_dtype=torch.float16)
|
model = AutoModel.from_pretrained(TINY_T5, torch_dtype=torch.float16)
|
||||||
self.assertEqual(model.dtype, torch.float16)
|
self.assertEqual(model.dtype, torch.float16)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user