[models] respect dtype of the model when instantiating it (#12316)
* [models] respect dtype of the model when instantiating it * cleanup * cleanup * rework to handle non-float dtype * fix * switch to fp32 tiny model * improve * use dtype.is_floating_point * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * fix the doc * recode to use explicit torch_dtype_auto_detect, torch_dtype args * docs and tweaks * docs and tweaks * docs and tweaks * merge 2 args, add docs * fix * fix * better doc * better doc Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -1549,6 +1549,8 @@ Note: If the fp16 weights of the model can't fit onto the memory of a single GPU
|
|||||||
For full details on this method and other related features please refer to `Constructing Massive Models
|
For full details on this method and other related features please refer to `Constructing Massive Models
|
||||||
<https://deepspeed.readthedocs.io/en/latest/zero3.html#constructing-massive-models>`__.
|
<https://deepspeed.readthedocs.io/en/latest/zero3.html#constructing-massive-models>`__.
|
||||||
|
|
||||||
|
Also when loading fp16-pretrained models, you will want to tell ``from_pretrained`` to use
|
||||||
|
``torch_dtype=torch.float16``. For details, please, see :ref:`from_pretrained-torch-dtype`.
|
||||||
|
|
||||||
|
|
||||||
Gathering Parameters
|
Gathering Parameters
|
||||||
|
|||||||
@@ -38,6 +38,37 @@ PreTrainedModel
|
|||||||
:members:
|
:members:
|
||||||
|
|
||||||
|
|
||||||
|
.. _from_pretrained-torch-dtype:
|
||||||
|
|
||||||
|
Model Instantiation dtype
|
||||||
|
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
Under Pytorch a model normally gets instantiated with ``torch.float32`` format. This can be an issue if one tries to
|
||||||
|
load a model whose weights are in fp16, since it'd require twice as much memory. To overcome this limitation, you can
|
||||||
|
either explicitly pass the desired ``dtype`` using ``torch_dtype`` argument:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
model = T5ForConditionalGeneration.from_pretrained("t5", torch_dtype=torch.float16)
|
||||||
|
|
||||||
|
or, if you want the model to always load in the most optimal memory pattern, you can use the special value ``"auto"``,
|
||||||
|
and then ``dtype`` will be automatically derived from the model's weights:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
model = T5ForConditionalGeneration.from_pretrained("t5", torch_dtype="auto")
|
||||||
|
|
||||||
|
Models instantiated from scratch can also be told which ``dtype`` to use with:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
config = T5Config.from_pretrained("t5")
|
||||||
|
model = AutoModel.from_config(config)
|
||||||
|
|
||||||
|
Due to Pytorch design, this functionality is only available for floating dtypes.
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
ModuleUtilsMixin
|
ModuleUtilsMixin
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
|||||||
@@ -192,6 +192,12 @@ class PretrainedConfig(PushToHubMixin):
|
|||||||
- **tie_word_embeddings** (:obj:`bool`, `optional`, defaults to :obj:`True`) -- Whether the model's input and
|
- **tie_word_embeddings** (:obj:`bool`, `optional`, defaults to :obj:`True`) -- Whether the model's input and
|
||||||
output word embeddings should be tied. Note that this is only relevant if the model has a output word
|
output word embeddings should be tied. Note that this is only relevant if the model has a output word
|
||||||
embedding layer.
|
embedding layer.
|
||||||
|
- **torch_dtype** (:obj:`str`, `optional`) -- The :obj:`dtype` of the weights. This attribute can be used to
|
||||||
|
initialize the model to a non-default ``dtype`` (which is normally ``float32``) and thus allow for optimal
|
||||||
|
storage allocation. For example, if the saved model is ``float16``, ideally we want to load it back using the
|
||||||
|
minimal amount of memory needed to load ``float16`` weights. Since the config object is stored in plain text,
|
||||||
|
this attribute contains just the floating type string without the ``torch.`` prefix. For example, for
|
||||||
|
``torch.float16`` ``torch_dtype`` is the ``"float16"`` string.
|
||||||
|
|
||||||
TensorFlow specific parameters
|
TensorFlow specific parameters
|
||||||
|
|
||||||
@@ -207,6 +213,7 @@ class PretrainedConfig(PushToHubMixin):
|
|||||||
self.output_hidden_states = kwargs.pop("output_hidden_states", False)
|
self.output_hidden_states = kwargs.pop("output_hidden_states", False)
|
||||||
self.output_attentions = kwargs.pop("output_attentions", False)
|
self.output_attentions = kwargs.pop("output_attentions", False)
|
||||||
self.torchscript = kwargs.pop("torchscript", False) # Only used by PyTorch models
|
self.torchscript = kwargs.pop("torchscript", False) # Only used by PyTorch models
|
||||||
|
self.torch_dtype = kwargs.pop("torch_dtype", None) # Only used by PyTorch models
|
||||||
self.use_bfloat16 = kwargs.pop("use_bfloat16", False)
|
self.use_bfloat16 = kwargs.pop("use_bfloat16", False)
|
||||||
self.pruned_heads = kwargs.pop("pruned_heads", {})
|
self.pruned_heads = kwargs.pop("pruned_heads", {})
|
||||||
self.tie_word_embeddings = kwargs.pop(
|
self.tie_word_embeddings = kwargs.pop(
|
||||||
|
|||||||
@@ -111,6 +111,13 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
|||||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> Dict:
|
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> Dict:
|
||||||
raise NotImplementedError(f"init method has to be implemented for {self}")
|
raise NotImplementedError(f"init method has to be implemented for {self}")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _from_config(cls, config, **kwargs):
|
||||||
|
"""
|
||||||
|
All context managers that the model should be initialized under go here.
|
||||||
|
"""
|
||||||
|
return cls(config, **kwargs)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def config(self) -> PretrainedConfig:
|
def config(self) -> PretrainedConfig:
|
||||||
return self._config
|
return self._config
|
||||||
|
|||||||
@@ -643,6 +643,13 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
|||||||
self.config = config
|
self.config = config
|
||||||
self.name_or_path = config.name_or_path
|
self.name_or_path = config.name_or_path
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _from_config(cls, config, **kwargs):
|
||||||
|
"""
|
||||||
|
All context managers that the model should be initialized under go here.
|
||||||
|
"""
|
||||||
|
return cls(config, **kwargs)
|
||||||
|
|
||||||
@tf.function(
|
@tf.function(
|
||||||
input_signature=[
|
input_signature=[
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ from dataclasses import dataclass
|
|||||||
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
|
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import Tensor, device, dtype, nn
|
from torch import Tensor, device, nn
|
||||||
from torch.nn import CrossEntropyLoss
|
from torch.nn import CrossEntropyLoss
|
||||||
|
|
||||||
from .activations import get_activation
|
from .activations import get_activation
|
||||||
@@ -201,7 +201,7 @@ class ModuleUtilsMixin:
|
|||||||
return get_parameter_device(self)
|
return get_parameter_device(self)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def dtype(self) -> dtype:
|
def dtype(self) -> torch.dtype:
|
||||||
"""
|
"""
|
||||||
:obj:`torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
|
:obj:`torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
|
||||||
"""
|
"""
|
||||||
@@ -464,6 +464,66 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
self.config = config
|
self.config = config
|
||||||
self.name_or_path = config.name_or_path
|
self.name_or_path = config.name_or_path
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _from_config(cls, config, **kwargs):
|
||||||
|
"""
|
||||||
|
All context managers that the model should be initialized under go here.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
torch_dtype (:obj:`torch.dtype`, `optional`):
|
||||||
|
Override the default ``torch.dtype`` and load the model under this dtype.
|
||||||
|
"""
|
||||||
|
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||||
|
|
||||||
|
# override default dtype if needed
|
||||||
|
dtype_orig = None
|
||||||
|
if torch_dtype is not None:
|
||||||
|
dtype_orig = cls._set_default_torch_dtype(torch_dtype)
|
||||||
|
|
||||||
|
if is_deepspeed_zero3_enabled():
|
||||||
|
import deepspeed
|
||||||
|
|
||||||
|
logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
|
||||||
|
# this immediately partitions the model across all gpus, to avoid the overhead in time
|
||||||
|
# and memory copying it on CPU or each GPU first
|
||||||
|
with deepspeed.zero.Init(config=deepspeed_config()):
|
||||||
|
model = cls(config, **kwargs)
|
||||||
|
else:
|
||||||
|
model = cls(config, **kwargs)
|
||||||
|
|
||||||
|
# restore default dtype if it was modified
|
||||||
|
if dtype_orig is not None:
|
||||||
|
torch.set_default_dtype(dtype_orig)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _set_default_torch_dtype(cls, dtype: torch.dtype) -> torch.dtype:
|
||||||
|
"""
|
||||||
|
Change the default dtype and return the previous one. This is needed when wanting to instantiate the model
|
||||||
|
under specific dtype.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dtype (:obj:`torch.dtype`):
|
||||||
|
a floating dtype to set to.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
:obj:`torch.dtype`: the original ``dtype`` that can be used to restore ``torch.set_default_dtype(dtype)``
|
||||||
|
if it was modified. If it wasn't, returns :obj:`None`.
|
||||||
|
|
||||||
|
Note ``set_default_dtype`` currently only works with floating-point types and asserts if for example,
|
||||||
|
``torch.int64`` is passed. So if a non-float ``dtype`` is passed this functions will throw an exception.
|
||||||
|
"""
|
||||||
|
if not dtype.is_floating_point:
|
||||||
|
raise ValueError(
|
||||||
|
f"Can't instantiate {cls.__name__} model under dtype={dtype} since it is not a floating point dtype"
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Instantiating {cls.__name__} model under default dtype {dtype}.")
|
||||||
|
dtype_orig = torch.get_default_dtype()
|
||||||
|
torch.set_default_dtype(dtype)
|
||||||
|
return dtype_orig
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def base_model(self) -> nn.Module:
|
def base_model(self) -> nn.Module:
|
||||||
"""
|
"""
|
||||||
@@ -876,6 +936,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
# Only save the model itself if we are using distributed training
|
# Only save the model itself if we are using distributed training
|
||||||
model_to_save = unwrap_model(self)
|
model_to_save = unwrap_model(self)
|
||||||
|
|
||||||
|
# save the string version of dtype to the config, e.g. convert torch.float32 => "float32"
|
||||||
|
# we currently don't use this setting automatically, but may start to use with v5
|
||||||
|
dtype = get_parameter_dtype(model_to_save)
|
||||||
|
model_to_save.config.torch_dtype = str(dtype).split(".")[1]
|
||||||
|
|
||||||
# Attach architecture to the config
|
# Attach architecture to the config
|
||||||
model_to_save.config.architectures = [model_to_save.__class__.__name__]
|
model_to_save.config.architectures = [model_to_save.__class__.__name__]
|
||||||
|
|
||||||
@@ -993,6 +1058,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
Please refer to the mirror site for more information.
|
Please refer to the mirror site for more information.
|
||||||
_fast_init(:obj:`bool`, `optional`, defaults to `:obj:`True`):
|
_fast_init(:obj:`bool`, `optional`, defaults to `:obj:`True`):
|
||||||
Whether or not to disable fast initialization.
|
Whether or not to disable fast initialization.
|
||||||
|
torch_dtype (:obj:`str` or :obj:`torch.dtype`, `optional`):
|
||||||
|
Override the default ``torch.dtype`` and load the model under this dtype. If ``"auto"`` is passed the
|
||||||
|
dtype will be automatically derived from the model's weights.
|
||||||
|
|
||||||
.. warning::
|
.. warning::
|
||||||
|
|
||||||
@@ -1058,6 +1126,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
from_pipeline = kwargs.pop("_from_pipeline", None)
|
from_pipeline = kwargs.pop("_from_pipeline", None)
|
||||||
from_auto_class = kwargs.pop("_from_auto", False)
|
from_auto_class = kwargs.pop("_from_auto", False)
|
||||||
_fast_init = kwargs.pop("_fast_init", True)
|
_fast_init = kwargs.pop("_fast_init", True)
|
||||||
|
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||||
|
|
||||||
|
from_pt = not (from_tf | from_flax)
|
||||||
|
|
||||||
user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class}
|
user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class}
|
||||||
if from_pipeline is not None:
|
if from_pipeline is not None:
|
||||||
@@ -1162,6 +1233,34 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
else:
|
else:
|
||||||
resolved_archive_file = None
|
resolved_archive_file = None
|
||||||
|
|
||||||
|
# load pt weights early so that we know which dtype to init the model under
|
||||||
|
if from_pt:
|
||||||
|
if state_dict is None:
|
||||||
|
try:
|
||||||
|
state_dict = torch.load(resolved_archive_file, map_location="cpu")
|
||||||
|
except Exception:
|
||||||
|
raise OSError(
|
||||||
|
f"Unable to load weights from pytorch checkpoint file for '{pretrained_model_name_or_path}' "
|
||||||
|
f"at '{resolved_archive_file}'"
|
||||||
|
"If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True. "
|
||||||
|
)
|
||||||
|
|
||||||
|
# set dtype to instantiate the model under:
|
||||||
|
# 1. If torch_dtype is not None, we use that dtype
|
||||||
|
# 2. If torch_dtype is "auto", we auto-detect dtype from the loaded state_dict, by checking its first
|
||||||
|
# weights entry - we assume all weights are of the same dtype
|
||||||
|
# we also may have config.torch_dtype available, but we won't rely on it till v5
|
||||||
|
dtype_orig = None
|
||||||
|
if torch_dtype is not None:
|
||||||
|
if isinstance(torch_dtype, str):
|
||||||
|
if torch_dtype == "auto":
|
||||||
|
torch_dtype = next(iter(state_dict.values())).dtype
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"`torch_dtype` can be either a `torch.dtype` or `auto`, but received {torch_dtype}"
|
||||||
|
)
|
||||||
|
dtype_orig = cls._set_default_torch_dtype(torch_dtype)
|
||||||
|
|
||||||
config.name_or_path = pretrained_model_name_or_path
|
config.name_or_path = pretrained_model_name_or_path
|
||||||
|
|
||||||
# Instantiate model.
|
# Instantiate model.
|
||||||
@@ -1178,6 +1277,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
with no_init_weights(_enable=_fast_init):
|
with no_init_weights(_enable=_fast_init):
|
||||||
model = cls(config, *model_args, **model_kwargs)
|
model = cls(config, *model_args, **model_kwargs)
|
||||||
|
|
||||||
|
if from_pt:
|
||||||
|
# restore default dtype
|
||||||
|
if dtype_orig is not None:
|
||||||
|
torch.set_default_dtype(dtype_orig)
|
||||||
|
|
||||||
if from_tf:
|
if from_tf:
|
||||||
if resolved_archive_file.endswith(".index"):
|
if resolved_archive_file.endswith(".index"):
|
||||||
# Load from a TensorFlow 1.X checkpoint - provided by original authors
|
# Load from a TensorFlow 1.X checkpoint - provided by original authors
|
||||||
@@ -1205,17 +1309,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
"https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation instructions."
|
"https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation instructions."
|
||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
else:
|
elif from_pt:
|
||||||
if state_dict is None:
|
|
||||||
try:
|
|
||||||
state_dict = torch.load(resolved_archive_file, map_location="cpu")
|
|
||||||
except Exception:
|
|
||||||
raise OSError(
|
|
||||||
f"Unable to load weights from pytorch checkpoint file for '{pretrained_model_name_or_path}' "
|
|
||||||
f"at '{resolved_archive_file}'"
|
|
||||||
"If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True. "
|
|
||||||
)
|
|
||||||
|
|
||||||
model, missing_keys, unexpected_keys, error_msgs = cls._load_state_dict_into_model(
|
model, missing_keys, unexpected_keys, error_msgs = cls._load_state_dict_into_model(
|
||||||
model, state_dict, pretrained_model_name_or_path, _fast_init=_fast_init
|
model, state_dict, pretrained_model_name_or_path, _fast_init=_fast_init
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -17,7 +17,6 @@
|
|||||||
import types
|
import types
|
||||||
|
|
||||||
from ...configuration_utils import PretrainedConfig
|
from ...configuration_utils import PretrainedConfig
|
||||||
from ...deepspeed import deepspeed_config, is_deepspeed_zero3_enabled
|
|
||||||
from ...file_utils import copy_func
|
from ...file_utils import copy_func
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
from .configuration_auto import AutoConfig, replace_list_option_in_docstrings
|
from .configuration_auto import AutoConfig, replace_list_option_in_docstrings
|
||||||
@@ -367,16 +366,8 @@ class _BaseAutoModelClass:
|
|||||||
def from_config(cls, config, **kwargs):
|
def from_config(cls, config, **kwargs):
|
||||||
if type(config) in cls._model_mapping.keys():
|
if type(config) in cls._model_mapping.keys():
|
||||||
model_class = _get_model_class(config, cls._model_mapping)
|
model_class = _get_model_class(config, cls._model_mapping)
|
||||||
if is_deepspeed_zero3_enabled():
|
return model_class._from_config(config, **kwargs)
|
||||||
import deepspeed
|
|
||||||
|
|
||||||
logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
|
|
||||||
# this immediately partitions the model across all gpus, to avoid the overhead in time
|
|
||||||
# and memory copying it on CPU or each GPU first
|
|
||||||
with deepspeed.zero.Init(config=deepspeed_config()):
|
|
||||||
return model_class(config, **kwargs)
|
|
||||||
else:
|
|
||||||
return model_class(config, **kwargs)
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
|
f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
|
||||||
f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}."
|
f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}."
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ from typing import Dict, List, Tuple
|
|||||||
|
|
||||||
from huggingface_hub import HfApi
|
from huggingface_hub import HfApi
|
||||||
from requests.exceptions import HTTPError
|
from requests.exceptions import HTTPError
|
||||||
from transformers import is_torch_available, logging
|
from transformers import AutoModel, is_torch_available, logging
|
||||||
from transformers.file_utils import WEIGHTS_NAME, is_torch_fx_available
|
from transformers.file_utils import WEIGHTS_NAME, is_torch_fx_available
|
||||||
from transformers.models.auto import get_values
|
from transformers.models.auto import get_values
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
@@ -33,6 +33,7 @@ from transformers.testing_utils import (
|
|||||||
PASS,
|
PASS,
|
||||||
USER,
|
USER,
|
||||||
CaptureLogger,
|
CaptureLogger,
|
||||||
|
TestCasePlus,
|
||||||
is_staging_test,
|
is_staging_test,
|
||||||
require_torch,
|
require_torch,
|
||||||
require_torch_multi_gpu,
|
require_torch_multi_gpu,
|
||||||
@@ -63,6 +64,7 @@ if is_torch_available():
|
|||||||
BertModel,
|
BertModel,
|
||||||
PretrainedConfig,
|
PretrainedConfig,
|
||||||
PreTrainedModel,
|
PreTrainedModel,
|
||||||
|
T5Config,
|
||||||
T5ForConditionalGeneration,
|
T5ForConditionalGeneration,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1574,7 +1576,7 @@ def floats_tensor(shape, scale=1.0, rng=None, name=None):
|
|||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class ModelUtilsTest(unittest.TestCase):
|
class ModelUtilsTest(TestCasePlus):
|
||||||
@slow
|
@slow
|
||||||
def test_model_from_pretrained(self):
|
def test_model_from_pretrained(self):
|
||||||
for model_name in BERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
for model_name in BERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||||
@@ -1607,6 +1609,60 @@ class ModelUtilsTest(unittest.TestCase):
|
|||||||
BertModel.from_pretrained(TINY_T5)
|
BertModel.from_pretrained(TINY_T5)
|
||||||
self.assertTrue("You are using a model of type t5 to instantiate a model of type bert" in cl.out)
|
self.assertTrue("You are using a model of type t5 to instantiate a model of type bert" in cl.out)
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
def test_model_from_config_torch_dtype(self):
|
||||||
|
# test that the model can be instantiated with dtype of user's choice - as long as it's a
|
||||||
|
# float dtype. To make it happen config.torch_dtype needs to be set before instantiating the
|
||||||
|
# model from the config object.
|
||||||
|
|
||||||
|
config = T5Config.from_pretrained(TINY_T5)
|
||||||
|
model = AutoModel.from_config(config)
|
||||||
|
# XXX: isn't supported
|
||||||
|
# model = T5ForConditionalGeneration.from_config(config)
|
||||||
|
self.assertEqual(model.dtype, torch.float32)
|
||||||
|
|
||||||
|
model = AutoModel.from_config(config, torch_dtype=torch.float16)
|
||||||
|
self.assertEqual(model.dtype, torch.float16)
|
||||||
|
|
||||||
|
# torch.set_default_dtype() supports only float dtypes, so will fail with non-float type
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
model = AutoModel.from_config(config, torch_dtype=torch.int64)
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
def test_model_from_pretrained_torch_dtype(self):
|
||||||
|
# test that the model can be instantiated with dtype of either
|
||||||
|
# 1. config.torch_dtype setting in the saved model (priority)
|
||||||
|
# 2. via autodiscovery by looking at model weights
|
||||||
|
# so if a model.half() was saved, we want it to be instantiated as such.
|
||||||
|
model_path = self.get_auto_remove_tmp_dir()
|
||||||
|
|
||||||
|
# baseline - we know TINY_T5 is fp32 model
|
||||||
|
model = T5ForConditionalGeneration.from_pretrained(TINY_T5)
|
||||||
|
self.assertEqual(model.dtype, torch.float32)
|
||||||
|
|
||||||
|
# test the default fp32 save_pretrained => from_pretrained cycle
|
||||||
|
model.save_pretrained(model_path)
|
||||||
|
model = T5ForConditionalGeneration.from_pretrained(model_path)
|
||||||
|
self.assertEqual(model.dtype, torch.float32)
|
||||||
|
# test with auto-detection
|
||||||
|
model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype="auto")
|
||||||
|
self.assertEqual(model.dtype, torch.float32)
|
||||||
|
|
||||||
|
# test forced loading in fp16 (even though the weights are in fp32)
|
||||||
|
model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype=torch.float16)
|
||||||
|
self.assertEqual(model.dtype, torch.float16)
|
||||||
|
|
||||||
|
# test fp16 save_pretrained, loaded with auto-detection
|
||||||
|
model = model.half()
|
||||||
|
model.save_pretrained(model_path)
|
||||||
|
model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype="auto")
|
||||||
|
self.assertEqual(model.config.torch_dtype, "float16") # tests `config.torch_dtype` saving
|
||||||
|
self.assertEqual(model.dtype, torch.float16)
|
||||||
|
|
||||||
|
# test fp16 save_pretrained, loaded with the explicit fp16
|
||||||
|
model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype=torch.float16)
|
||||||
|
self.assertEqual(model.dtype, torch.float16)
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
@is_staging_test
|
@is_staging_test
|
||||||
|
|||||||
Reference in New Issue
Block a user