[models] respect dtype of the model when instantiating it (#12316)
* [models] respect dtype of the model when instantiating it * cleanup * cleanup * rework to handle non-float dtype * fix * switch to fp32 tiny model * improve * use dtype.is_floating_point * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * fix the doc * recode to use explicit torch_dtype_auto_detect, torch_dtype args * docs and tweaks * docs and tweaks * docs and tweaks * merge 2 args, add docs * fix * fix * better doc * better doc Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -1549,6 +1549,8 @@ Note: If the fp16 weights of the model can't fit onto the memory of a single GPU
|
||||
For full details on this method and other related features please refer to `Constructing Massive Models
|
||||
<https://deepspeed.readthedocs.io/en/latest/zero3.html#constructing-massive-models>`__.
|
||||
|
||||
Also when loading fp16-pretrained models, you will want to tell ``from_pretrained`` to use
|
||||
``torch_dtype=torch.float16``. For details, please, see :ref:`from_pretrained-torch-dtype`.
|
||||
|
||||
|
||||
Gathering Parameters
|
||||
|
||||
@@ -38,6 +38,37 @@ PreTrainedModel
|
||||
:members:
|
||||
|
||||
|
||||
.. _from_pretrained-torch-dtype:
|
||||
|
||||
Model Instantiation dtype
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
Under Pytorch a model normally gets instantiated with ``torch.float32`` format. This can be an issue if one tries to
|
||||
load a model whose weights are in fp16, since it'd require twice as much memory. To overcome this limitation, you can
|
||||
either explicitly pass the desired ``dtype`` using ``torch_dtype`` argument:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
model = T5ForConditionalGeneration.from_pretrained("t5", torch_dtype=torch.float16)
|
||||
|
||||
or, if you want the model to always load in the most optimal memory pattern, you can use the special value ``"auto"``,
|
||||
and then ``dtype`` will be automatically derived from the model's weights:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
model = T5ForConditionalGeneration.from_pretrained("t5", torch_dtype="auto")
|
||||
|
||||
Models instantiated from scratch can also be told which ``dtype`` to use with:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
config = T5Config.from_pretrained("t5")
|
||||
model = AutoModel.from_config(config)
|
||||
|
||||
Due to Pytorch design, this functionality is only available for floating dtypes.
|
||||
|
||||
|
||||
|
||||
ModuleUtilsMixin
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
|
||||
@@ -192,6 +192,12 @@ class PretrainedConfig(PushToHubMixin):
|
||||
- **tie_word_embeddings** (:obj:`bool`, `optional`, defaults to :obj:`True`) -- Whether the model's input and
|
||||
output word embeddings should be tied. Note that this is only relevant if the model has a output word
|
||||
embedding layer.
|
||||
- **torch_dtype** (:obj:`str`, `optional`) -- The :obj:`dtype` of the weights. This attribute can be used to
|
||||
initialize the model to a non-default ``dtype`` (which is normally ``float32``) and thus allow for optimal
|
||||
storage allocation. For example, if the saved model is ``float16``, ideally we want to load it back using the
|
||||
minimal amount of memory needed to load ``float16`` weights. Since the config object is stored in plain text,
|
||||
this attribute contains just the floating type string without the ``torch.`` prefix. For example, for
|
||||
``torch.float16`` ``torch_dtype`` is the ``"float16"`` string.
|
||||
|
||||
TensorFlow specific parameters
|
||||
|
||||
@@ -207,6 +213,7 @@ class PretrainedConfig(PushToHubMixin):
|
||||
self.output_hidden_states = kwargs.pop("output_hidden_states", False)
|
||||
self.output_attentions = kwargs.pop("output_attentions", False)
|
||||
self.torchscript = kwargs.pop("torchscript", False) # Only used by PyTorch models
|
||||
self.torch_dtype = kwargs.pop("torch_dtype", None) # Only used by PyTorch models
|
||||
self.use_bfloat16 = kwargs.pop("use_bfloat16", False)
|
||||
self.pruned_heads = kwargs.pop("pruned_heads", {})
|
||||
self.tie_word_embeddings = kwargs.pop(
|
||||
|
||||
@@ -111,6 +111,13 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> Dict:
|
||||
raise NotImplementedError(f"init method has to be implemented for {self}")
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config, **kwargs):
|
||||
"""
|
||||
All context managers that the model should be initialized under go here.
|
||||
"""
|
||||
return cls(config, **kwargs)
|
||||
|
||||
@property
|
||||
def config(self) -> PretrainedConfig:
|
||||
return self._config
|
||||
|
||||
@@ -643,6 +643,13 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
||||
self.config = config
|
||||
self.name_or_path = config.name_or_path
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config, **kwargs):
|
||||
"""
|
||||
All context managers that the model should be initialized under go here.
|
||||
"""
|
||||
return cls(config, **kwargs)
|
||||
|
||||
@tf.function(
|
||||
input_signature=[
|
||||
{
|
||||
|
||||
@@ -23,7 +23,7 @@ from dataclasses import dataclass
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import Tensor, device, dtype, nn
|
||||
from torch import Tensor, device, nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
|
||||
from .activations import get_activation
|
||||
@@ -201,7 +201,7 @@ class ModuleUtilsMixin:
|
||||
return get_parameter_device(self)
|
||||
|
||||
@property
|
||||
def dtype(self) -> dtype:
|
||||
def dtype(self) -> torch.dtype:
|
||||
"""
|
||||
:obj:`torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
|
||||
"""
|
||||
@@ -464,6 +464,66 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
self.config = config
|
||||
self.name_or_path = config.name_or_path
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config, **kwargs):
|
||||
"""
|
||||
All context managers that the model should be initialized under go here.
|
||||
|
||||
Args:
|
||||
torch_dtype (:obj:`torch.dtype`, `optional`):
|
||||
Override the default ``torch.dtype`` and load the model under this dtype.
|
||||
"""
|
||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||
|
||||
# override default dtype if needed
|
||||
dtype_orig = None
|
||||
if torch_dtype is not None:
|
||||
dtype_orig = cls._set_default_torch_dtype(torch_dtype)
|
||||
|
||||
if is_deepspeed_zero3_enabled():
|
||||
import deepspeed
|
||||
|
||||
logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
|
||||
# this immediately partitions the model across all gpus, to avoid the overhead in time
|
||||
# and memory copying it on CPU or each GPU first
|
||||
with deepspeed.zero.Init(config=deepspeed_config()):
|
||||
model = cls(config, **kwargs)
|
||||
else:
|
||||
model = cls(config, **kwargs)
|
||||
|
||||
# restore default dtype if it was modified
|
||||
if dtype_orig is not None:
|
||||
torch.set_default_dtype(dtype_orig)
|
||||
|
||||
return model
|
||||
|
||||
@classmethod
|
||||
def _set_default_torch_dtype(cls, dtype: torch.dtype) -> torch.dtype:
|
||||
"""
|
||||
Change the default dtype and return the previous one. This is needed when wanting to instantiate the model
|
||||
under specific dtype.
|
||||
|
||||
Args:
|
||||
dtype (:obj:`torch.dtype`):
|
||||
a floating dtype to set to.
|
||||
|
||||
Returns:
|
||||
:obj:`torch.dtype`: the original ``dtype`` that can be used to restore ``torch.set_default_dtype(dtype)``
|
||||
if it was modified. If it wasn't, returns :obj:`None`.
|
||||
|
||||
Note ``set_default_dtype`` currently only works with floating-point types and asserts if for example,
|
||||
``torch.int64`` is passed. So if a non-float ``dtype`` is passed this functions will throw an exception.
|
||||
"""
|
||||
if not dtype.is_floating_point:
|
||||
raise ValueError(
|
||||
f"Can't instantiate {cls.__name__} model under dtype={dtype} since it is not a floating point dtype"
|
||||
)
|
||||
|
||||
logger.info(f"Instantiating {cls.__name__} model under default dtype {dtype}.")
|
||||
dtype_orig = torch.get_default_dtype()
|
||||
torch.set_default_dtype(dtype)
|
||||
return dtype_orig
|
||||
|
||||
@property
|
||||
def base_model(self) -> nn.Module:
|
||||
"""
|
||||
@@ -876,6 +936,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
# Only save the model itself if we are using distributed training
|
||||
model_to_save = unwrap_model(self)
|
||||
|
||||
# save the string version of dtype to the config, e.g. convert torch.float32 => "float32"
|
||||
# we currently don't use this setting automatically, but may start to use with v5
|
||||
dtype = get_parameter_dtype(model_to_save)
|
||||
model_to_save.config.torch_dtype = str(dtype).split(".")[1]
|
||||
|
||||
# Attach architecture to the config
|
||||
model_to_save.config.architectures = [model_to_save.__class__.__name__]
|
||||
|
||||
@@ -993,6 +1058,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
Please refer to the mirror site for more information.
|
||||
_fast_init(:obj:`bool`, `optional`, defaults to `:obj:`True`):
|
||||
Whether or not to disable fast initialization.
|
||||
torch_dtype (:obj:`str` or :obj:`torch.dtype`, `optional`):
|
||||
Override the default ``torch.dtype`` and load the model under this dtype. If ``"auto"`` is passed the
|
||||
dtype will be automatically derived from the model's weights.
|
||||
|
||||
.. warning::
|
||||
|
||||
@@ -1058,6 +1126,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
from_pipeline = kwargs.pop("_from_pipeline", None)
|
||||
from_auto_class = kwargs.pop("_from_auto", False)
|
||||
_fast_init = kwargs.pop("_fast_init", True)
|
||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||
|
||||
from_pt = not (from_tf | from_flax)
|
||||
|
||||
user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class}
|
||||
if from_pipeline is not None:
|
||||
@@ -1162,6 +1233,34 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
else:
|
||||
resolved_archive_file = None
|
||||
|
||||
# load pt weights early so that we know which dtype to init the model under
|
||||
if from_pt:
|
||||
if state_dict is None:
|
||||
try:
|
||||
state_dict = torch.load(resolved_archive_file, map_location="cpu")
|
||||
except Exception:
|
||||
raise OSError(
|
||||
f"Unable to load weights from pytorch checkpoint file for '{pretrained_model_name_or_path}' "
|
||||
f"at '{resolved_archive_file}'"
|
||||
"If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True. "
|
||||
)
|
||||
|
||||
# set dtype to instantiate the model under:
|
||||
# 1. If torch_dtype is not None, we use that dtype
|
||||
# 2. If torch_dtype is "auto", we auto-detect dtype from the loaded state_dict, by checking its first
|
||||
# weights entry - we assume all weights are of the same dtype
|
||||
# we also may have config.torch_dtype available, but we won't rely on it till v5
|
||||
dtype_orig = None
|
||||
if torch_dtype is not None:
|
||||
if isinstance(torch_dtype, str):
|
||||
if torch_dtype == "auto":
|
||||
torch_dtype = next(iter(state_dict.values())).dtype
|
||||
else:
|
||||
raise ValueError(
|
||||
f"`torch_dtype` can be either a `torch.dtype` or `auto`, but received {torch_dtype}"
|
||||
)
|
||||
dtype_orig = cls._set_default_torch_dtype(torch_dtype)
|
||||
|
||||
config.name_or_path = pretrained_model_name_or_path
|
||||
|
||||
# Instantiate model.
|
||||
@@ -1178,6 +1277,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
with no_init_weights(_enable=_fast_init):
|
||||
model = cls(config, *model_args, **model_kwargs)
|
||||
|
||||
if from_pt:
|
||||
# restore default dtype
|
||||
if dtype_orig is not None:
|
||||
torch.set_default_dtype(dtype_orig)
|
||||
|
||||
if from_tf:
|
||||
if resolved_archive_file.endswith(".index"):
|
||||
# Load from a TensorFlow 1.X checkpoint - provided by original authors
|
||||
@@ -1205,17 +1309,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
"https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation instructions."
|
||||
)
|
||||
raise
|
||||
else:
|
||||
if state_dict is None:
|
||||
try:
|
||||
state_dict = torch.load(resolved_archive_file, map_location="cpu")
|
||||
except Exception:
|
||||
raise OSError(
|
||||
f"Unable to load weights from pytorch checkpoint file for '{pretrained_model_name_or_path}' "
|
||||
f"at '{resolved_archive_file}'"
|
||||
"If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True. "
|
||||
)
|
||||
|
||||
elif from_pt:
|
||||
model, missing_keys, unexpected_keys, error_msgs = cls._load_state_dict_into_model(
|
||||
model, state_dict, pretrained_model_name_or_path, _fast_init=_fast_init
|
||||
)
|
||||
|
||||
@@ -17,7 +17,6 @@
|
||||
import types
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...deepspeed import deepspeed_config, is_deepspeed_zero3_enabled
|
||||
from ...file_utils import copy_func
|
||||
from ...utils import logging
|
||||
from .configuration_auto import AutoConfig, replace_list_option_in_docstrings
|
||||
@@ -367,16 +366,8 @@ class _BaseAutoModelClass:
|
||||
def from_config(cls, config, **kwargs):
|
||||
if type(config) in cls._model_mapping.keys():
|
||||
model_class = _get_model_class(config, cls._model_mapping)
|
||||
if is_deepspeed_zero3_enabled():
|
||||
import deepspeed
|
||||
return model_class._from_config(config, **kwargs)
|
||||
|
||||
logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
|
||||
# this immediately partitions the model across all gpus, to avoid the overhead in time
|
||||
# and memory copying it on CPU or each GPU first
|
||||
with deepspeed.zero.Init(config=deepspeed_config()):
|
||||
return model_class(config, **kwargs)
|
||||
else:
|
||||
return model_class(config, **kwargs)
|
||||
raise ValueError(
|
||||
f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
|
||||
f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}."
|
||||
|
||||
@@ -25,7 +25,7 @@ from typing import Dict, List, Tuple
|
||||
|
||||
from huggingface_hub import HfApi
|
||||
from requests.exceptions import HTTPError
|
||||
from transformers import is_torch_available, logging
|
||||
from transformers import AutoModel, is_torch_available, logging
|
||||
from transformers.file_utils import WEIGHTS_NAME, is_torch_fx_available
|
||||
from transformers.models.auto import get_values
|
||||
from transformers.testing_utils import (
|
||||
@@ -33,6 +33,7 @@ from transformers.testing_utils import (
|
||||
PASS,
|
||||
USER,
|
||||
CaptureLogger,
|
||||
TestCasePlus,
|
||||
is_staging_test,
|
||||
require_torch,
|
||||
require_torch_multi_gpu,
|
||||
@@ -63,6 +64,7 @@ if is_torch_available():
|
||||
BertModel,
|
||||
PretrainedConfig,
|
||||
PreTrainedModel,
|
||||
T5Config,
|
||||
T5ForConditionalGeneration,
|
||||
)
|
||||
|
||||
@@ -1574,7 +1576,7 @@ def floats_tensor(shape, scale=1.0, rng=None, name=None):
|
||||
|
||||
|
||||
@require_torch
|
||||
class ModelUtilsTest(unittest.TestCase):
|
||||
class ModelUtilsTest(TestCasePlus):
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
for model_name in BERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||
@@ -1607,6 +1609,60 @@ class ModelUtilsTest(unittest.TestCase):
|
||||
BertModel.from_pretrained(TINY_T5)
|
||||
self.assertTrue("You are using a model of type t5 to instantiate a model of type bert" in cl.out)
|
||||
|
||||
@require_torch
|
||||
def test_model_from_config_torch_dtype(self):
|
||||
# test that the model can be instantiated with dtype of user's choice - as long as it's a
|
||||
# float dtype. To make it happen config.torch_dtype needs to be set before instantiating the
|
||||
# model from the config object.
|
||||
|
||||
config = T5Config.from_pretrained(TINY_T5)
|
||||
model = AutoModel.from_config(config)
|
||||
# XXX: isn't supported
|
||||
# model = T5ForConditionalGeneration.from_config(config)
|
||||
self.assertEqual(model.dtype, torch.float32)
|
||||
|
||||
model = AutoModel.from_config(config, torch_dtype=torch.float16)
|
||||
self.assertEqual(model.dtype, torch.float16)
|
||||
|
||||
# torch.set_default_dtype() supports only float dtypes, so will fail with non-float type
|
||||
with self.assertRaises(ValueError):
|
||||
model = AutoModel.from_config(config, torch_dtype=torch.int64)
|
||||
|
||||
@require_torch
|
||||
def test_model_from_pretrained_torch_dtype(self):
|
||||
# test that the model can be instantiated with dtype of either
|
||||
# 1. config.torch_dtype setting in the saved model (priority)
|
||||
# 2. via autodiscovery by looking at model weights
|
||||
# so if a model.half() was saved, we want it to be instantiated as such.
|
||||
model_path = self.get_auto_remove_tmp_dir()
|
||||
|
||||
# baseline - we know TINY_T5 is fp32 model
|
||||
model = T5ForConditionalGeneration.from_pretrained(TINY_T5)
|
||||
self.assertEqual(model.dtype, torch.float32)
|
||||
|
||||
# test the default fp32 save_pretrained => from_pretrained cycle
|
||||
model.save_pretrained(model_path)
|
||||
model = T5ForConditionalGeneration.from_pretrained(model_path)
|
||||
self.assertEqual(model.dtype, torch.float32)
|
||||
# test with auto-detection
|
||||
model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype="auto")
|
||||
self.assertEqual(model.dtype, torch.float32)
|
||||
|
||||
# test forced loading in fp16 (even though the weights are in fp32)
|
||||
model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype=torch.float16)
|
||||
self.assertEqual(model.dtype, torch.float16)
|
||||
|
||||
# test fp16 save_pretrained, loaded with auto-detection
|
||||
model = model.half()
|
||||
model.save_pretrained(model_path)
|
||||
model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype="auto")
|
||||
self.assertEqual(model.config.torch_dtype, "float16") # tests `config.torch_dtype` saving
|
||||
self.assertEqual(model.dtype, torch.float16)
|
||||
|
||||
# test fp16 save_pretrained, loaded with the explicit fp16
|
||||
model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype=torch.float16)
|
||||
self.assertEqual(model.dtype, torch.float16)
|
||||
|
||||
|
||||
@require_torch
|
||||
@is_staging_test
|
||||
|
||||
Reference in New Issue
Block a user