From 7682e97702e4317231b3afe92359de384dba1e20 Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Mon, 28 Jun 2021 20:11:21 -0700 Subject: [PATCH] [models] respect dtype of the model when instantiating it (#12316) * [models] respect dtype of the model when instantiating it * cleanup * cleanup * rework to handle non-float dtype * fix * switch to fp32 tiny model * improve * use dtype.is_floating_point * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * fix the doc * recode to use explicit torch_dtype_auto_detect, torch_dtype args * docs and tweaks * docs and tweaks * docs and tweaks * merge 2 args, add docs * fix * fix * better doc * better doc Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --- docs/source/main_classes/deepspeed.rst | 2 + docs/source/main_classes/model.rst | 33 ++++- src/transformers/configuration_utils.py | 7 ++ src/transformers/modeling_flax_utils.py | 7 ++ src/transformers/modeling_tf_utils.py | 7 ++ src/transformers/modeling_utils.py | 120 +++++++++++++++++-- src/transformers/models/auto/auto_factory.py | 11 +- tests/test_modeling_common.py | 60 +++++++++- 8 files changed, 221 insertions(+), 26 deletions(-) diff --git a/docs/source/main_classes/deepspeed.rst b/docs/source/main_classes/deepspeed.rst index aa47bf284b..619dfd4b8a 100644 --- a/docs/source/main_classes/deepspeed.rst +++ b/docs/source/main_classes/deepspeed.rst @@ -1549,6 +1549,8 @@ Note: If the fp16 weights of the model can't fit onto the memory of a single GPU For full details on this method and other related features please refer to `Constructing Massive Models `__. +Also when loading fp16-pretrained models, you will want to tell ``from_pretrained`` to use +``torch_dtype=torch.float16``. For details, please, see :ref:`from_pretrained-torch-dtype`. Gathering Parameters diff --git a/docs/source/main_classes/model.rst b/docs/source/main_classes/model.rst index e311a36eaa..d3bb0e2326 100644 --- a/docs/source/main_classes/model.rst +++ b/docs/source/main_classes/model.rst @@ -1,4 +1,4 @@ -.. +.. Copyright 2020 The HuggingFace Team. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with @@ -38,6 +38,37 @@ PreTrainedModel :members: +.. _from_pretrained-torch-dtype: + +Model Instantiation dtype +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Under Pytorch a model normally gets instantiated with ``torch.float32`` format. This can be an issue if one tries to +load a model whose weights are in fp16, since it'd require twice as much memory. To overcome this limitation, you can +either explicitly pass the desired ``dtype`` using ``torch_dtype`` argument: + +.. code-block:: python + + model = T5ForConditionalGeneration.from_pretrained("t5", torch_dtype=torch.float16) + +or, if you want the model to always load in the most optimal memory pattern, you can use the special value ``"auto"``, +and then ``dtype`` will be automatically derived from the model's weights: + +.. code-block:: python + + model = T5ForConditionalGeneration.from_pretrained("t5", torch_dtype="auto") + +Models instantiated from scratch can also be told which ``dtype`` to use with: + +.. code-block:: python + + config = T5Config.from_pretrained("t5") + model = AutoModel.from_config(config) + +Due to Pytorch design, this functionality is only available for floating dtypes. + + + ModuleUtilsMixin ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index 574d6daa4e..5490b5d611 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -192,6 +192,12 @@ class PretrainedConfig(PushToHubMixin): - **tie_word_embeddings** (:obj:`bool`, `optional`, defaults to :obj:`True`) -- Whether the model's input and output word embeddings should be tied. Note that this is only relevant if the model has a output word embedding layer. + - **torch_dtype** (:obj:`str`, `optional`) -- The :obj:`dtype` of the weights. This attribute can be used to + initialize the model to a non-default ``dtype`` (which is normally ``float32``) and thus allow for optimal + storage allocation. For example, if the saved model is ``float16``, ideally we want to load it back using the + minimal amount of memory needed to load ``float16`` weights. Since the config object is stored in plain text, + this attribute contains just the floating type string without the ``torch.`` prefix. For example, for + ``torch.float16`` ``torch_dtype`` is the ``"float16"`` string. TensorFlow specific parameters @@ -207,6 +213,7 @@ class PretrainedConfig(PushToHubMixin): self.output_hidden_states = kwargs.pop("output_hidden_states", False) self.output_attentions = kwargs.pop("output_attentions", False) self.torchscript = kwargs.pop("torchscript", False) # Only used by PyTorch models + self.torch_dtype = kwargs.pop("torch_dtype", None) # Only used by PyTorch models self.use_bfloat16 = kwargs.pop("use_bfloat16", False) self.pruned_heads = kwargs.pop("pruned_heads", {}) self.tie_word_embeddings = kwargs.pop( diff --git a/src/transformers/modeling_flax_utils.py b/src/transformers/modeling_flax_utils.py index 8d5006791a..00ccdfcfb7 100644 --- a/src/transformers/modeling_flax_utils.py +++ b/src/transformers/modeling_flax_utils.py @@ -111,6 +111,13 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin): def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> Dict: raise NotImplementedError(f"init method has to be implemented for {self}") + @classmethod + def _from_config(cls, config, **kwargs): + """ + All context managers that the model should be initialized under go here. + """ + return cls(config, **kwargs) + @property def config(self) -> PretrainedConfig: return self._config diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index e490dfaa55..b2587353b6 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -643,6 +643,13 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu self.config = config self.name_or_path = config.name_or_path + @classmethod + def _from_config(cls, config, **kwargs): + """ + All context managers that the model should be initialized under go here. + """ + return cls(config, **kwargs) + @tf.function( input_signature=[ { diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 3da1ea4484..86b1003d85 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -23,7 +23,7 @@ from dataclasses import dataclass from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union import torch -from torch import Tensor, device, dtype, nn +from torch import Tensor, device, nn from torch.nn import CrossEntropyLoss from .activations import get_activation @@ -201,7 +201,7 @@ class ModuleUtilsMixin: return get_parameter_device(self) @property - def dtype(self) -> dtype: + def dtype(self) -> torch.dtype: """ :obj:`torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype). """ @@ -464,6 +464,66 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix self.config = config self.name_or_path = config.name_or_path + @classmethod + def _from_config(cls, config, **kwargs): + """ + All context managers that the model should be initialized under go here. + + Args: + torch_dtype (:obj:`torch.dtype`, `optional`): + Override the default ``torch.dtype`` and load the model under this dtype. + """ + torch_dtype = kwargs.pop("torch_dtype", None) + + # override default dtype if needed + dtype_orig = None + if torch_dtype is not None: + dtype_orig = cls._set_default_torch_dtype(torch_dtype) + + if is_deepspeed_zero3_enabled(): + import deepspeed + + logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model") + # this immediately partitions the model across all gpus, to avoid the overhead in time + # and memory copying it on CPU or each GPU first + with deepspeed.zero.Init(config=deepspeed_config()): + model = cls(config, **kwargs) + else: + model = cls(config, **kwargs) + + # restore default dtype if it was modified + if dtype_orig is not None: + torch.set_default_dtype(dtype_orig) + + return model + + @classmethod + def _set_default_torch_dtype(cls, dtype: torch.dtype) -> torch.dtype: + """ + Change the default dtype and return the previous one. This is needed when wanting to instantiate the model + under specific dtype. + + Args: + dtype (:obj:`torch.dtype`): + a floating dtype to set to. + + Returns: + :obj:`torch.dtype`: the original ``dtype`` that can be used to restore ``torch.set_default_dtype(dtype)`` + if it was modified. If it wasn't, returns :obj:`None`. + + Note ``set_default_dtype`` currently only works with floating-point types and asserts if for example, + ``torch.int64`` is passed. So if a non-float ``dtype`` is passed this functions will throw an exception. + """ + if not dtype.is_floating_point: + raise ValueError( + f"Can't instantiate {cls.__name__} model under dtype={dtype} since it is not a floating point dtype" + ) + + logger.info(f"Instantiating {cls.__name__} model under default dtype {dtype}.") + dtype_orig = torch.get_default_dtype() + torch.set_default_dtype(dtype) + return dtype_orig + @property def base_model(self) -> nn.Module: """ @@ -876,6 +936,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix # Only save the model itself if we are using distributed training model_to_save = unwrap_model(self) + # save the string version of dtype to the config, e.g. convert torch.float32 => "float32" + # we currently don't use this setting automatically, but may start to use with v5 + dtype = get_parameter_dtype(model_to_save) + model_to_save.config.torch_dtype = str(dtype).split(".")[1] + # Attach architecture to the config model_to_save.config.architectures = [model_to_save.__class__.__name__] @@ -993,6 +1058,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix Please refer to the mirror site for more information. _fast_init(:obj:`bool`, `optional`, defaults to `:obj:`True`): Whether or not to disable fast initialization. + torch_dtype (:obj:`str` or :obj:`torch.dtype`, `optional`): + Override the default ``torch.dtype`` and load the model under this dtype. If ``"auto"`` is passed the + dtype will be automatically derived from the model's weights. .. warning:: @@ -1058,6 +1126,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix from_pipeline = kwargs.pop("_from_pipeline", None) from_auto_class = kwargs.pop("_from_auto", False) _fast_init = kwargs.pop("_fast_init", True) + torch_dtype = kwargs.pop("torch_dtype", None) + + from_pt = not (from_tf | from_flax) user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class} if from_pipeline is not None: @@ -1162,6 +1233,34 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix else: resolved_archive_file = None + # load pt weights early so that we know which dtype to init the model under + if from_pt: + if state_dict is None: + try: + state_dict = torch.load(resolved_archive_file, map_location="cpu") + except Exception: + raise OSError( + f"Unable to load weights from pytorch checkpoint file for '{pretrained_model_name_or_path}' " + f"at '{resolved_archive_file}'" + "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True. " + ) + + # set dtype to instantiate the model under: + # 1. If torch_dtype is not None, we use that dtype + # 2. If torch_dtype is "auto", we auto-detect dtype from the loaded state_dict, by checking its first + # weights entry - we assume all weights are of the same dtype + # we also may have config.torch_dtype available, but we won't rely on it till v5 + dtype_orig = None + if torch_dtype is not None: + if isinstance(torch_dtype, str): + if torch_dtype == "auto": + torch_dtype = next(iter(state_dict.values())).dtype + else: + raise ValueError( + f"`torch_dtype` can be either a `torch.dtype` or `auto`, but received {torch_dtype}" + ) + dtype_orig = cls._set_default_torch_dtype(torch_dtype) + config.name_or_path = pretrained_model_name_or_path # Instantiate model. @@ -1178,6 +1277,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix with no_init_weights(_enable=_fast_init): model = cls(config, *model_args, **model_kwargs) + if from_pt: + # restore default dtype + if dtype_orig is not None: + torch.set_default_dtype(dtype_orig) + if from_tf: if resolved_archive_file.endswith(".index"): # Load from a TensorFlow 1.X checkpoint - provided by original authors @@ -1205,17 +1309,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix "https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation instructions." ) raise - else: - if state_dict is None: - try: - state_dict = torch.load(resolved_archive_file, map_location="cpu") - except Exception: - raise OSError( - f"Unable to load weights from pytorch checkpoint file for '{pretrained_model_name_or_path}' " - f"at '{resolved_archive_file}'" - "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True. " - ) - + elif from_pt: model, missing_keys, unexpected_keys, error_msgs = cls._load_state_dict_into_model( model, state_dict, pretrained_model_name_or_path, _fast_init=_fast_init ) diff --git a/src/transformers/models/auto/auto_factory.py b/src/transformers/models/auto/auto_factory.py index 0d82184be5..77788d063f 100644 --- a/src/transformers/models/auto/auto_factory.py +++ b/src/transformers/models/auto/auto_factory.py @@ -17,7 +17,6 @@ import types from ...configuration_utils import PretrainedConfig -from ...deepspeed import deepspeed_config, is_deepspeed_zero3_enabled from ...file_utils import copy_func from ...utils import logging from .configuration_auto import AutoConfig, replace_list_option_in_docstrings @@ -367,16 +366,8 @@ class _BaseAutoModelClass: def from_config(cls, config, **kwargs): if type(config) in cls._model_mapping.keys(): model_class = _get_model_class(config, cls._model_mapping) - if is_deepspeed_zero3_enabled(): - import deepspeed + return model_class._from_config(config, **kwargs) - logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model") - # this immediately partitions the model across all gpus, to avoid the overhead in time - # and memory copying it on CPU or each GPU first - with deepspeed.zero.Init(config=deepspeed_config()): - return model_class(config, **kwargs) - else: - return model_class(config, **kwargs) raise ValueError( f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n" f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}." diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 1af00b909d..6c2eebb9ac 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -25,7 +25,7 @@ from typing import Dict, List, Tuple from huggingface_hub import HfApi from requests.exceptions import HTTPError -from transformers import is_torch_available, logging +from transformers import AutoModel, is_torch_available, logging from transformers.file_utils import WEIGHTS_NAME, is_torch_fx_available from transformers.models.auto import get_values from transformers.testing_utils import ( @@ -33,6 +33,7 @@ from transformers.testing_utils import ( PASS, USER, CaptureLogger, + TestCasePlus, is_staging_test, require_torch, require_torch_multi_gpu, @@ -63,6 +64,7 @@ if is_torch_available(): BertModel, PretrainedConfig, PreTrainedModel, + T5Config, T5ForConditionalGeneration, ) @@ -1574,7 +1576,7 @@ def floats_tensor(shape, scale=1.0, rng=None, name=None): @require_torch -class ModelUtilsTest(unittest.TestCase): +class ModelUtilsTest(TestCasePlus): @slow def test_model_from_pretrained(self): for model_name in BERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: @@ -1607,6 +1609,60 @@ class ModelUtilsTest(unittest.TestCase): BertModel.from_pretrained(TINY_T5) self.assertTrue("You are using a model of type t5 to instantiate a model of type bert" in cl.out) + @require_torch + def test_model_from_config_torch_dtype(self): + # test that the model can be instantiated with dtype of user's choice - as long as it's a + # float dtype. To make it happen config.torch_dtype needs to be set before instantiating the + # model from the config object. + + config = T5Config.from_pretrained(TINY_T5) + model = AutoModel.from_config(config) + # XXX: isn't supported + # model = T5ForConditionalGeneration.from_config(config) + self.assertEqual(model.dtype, torch.float32) + + model = AutoModel.from_config(config, torch_dtype=torch.float16) + self.assertEqual(model.dtype, torch.float16) + + # torch.set_default_dtype() supports only float dtypes, so will fail with non-float type + with self.assertRaises(ValueError): + model = AutoModel.from_config(config, torch_dtype=torch.int64) + + @require_torch + def test_model_from_pretrained_torch_dtype(self): + # test that the model can be instantiated with dtype of either + # 1. config.torch_dtype setting in the saved model (priority) + # 2. via autodiscovery by looking at model weights + # so if a model.half() was saved, we want it to be instantiated as such. + model_path = self.get_auto_remove_tmp_dir() + + # baseline - we know TINY_T5 is fp32 model + model = T5ForConditionalGeneration.from_pretrained(TINY_T5) + self.assertEqual(model.dtype, torch.float32) + + # test the default fp32 save_pretrained => from_pretrained cycle + model.save_pretrained(model_path) + model = T5ForConditionalGeneration.from_pretrained(model_path) + self.assertEqual(model.dtype, torch.float32) + # test with auto-detection + model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype="auto") + self.assertEqual(model.dtype, torch.float32) + + # test forced loading in fp16 (even though the weights are in fp32) + model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype=torch.float16) + self.assertEqual(model.dtype, torch.float16) + + # test fp16 save_pretrained, loaded with auto-detection + model = model.half() + model.save_pretrained(model_path) + model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype="auto") + self.assertEqual(model.config.torch_dtype, "float16") # tests `config.torch_dtype` saving + self.assertEqual(model.dtype, torch.float16) + + # test fp16 save_pretrained, loaded with the explicit fp16 + model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype=torch.float16) + self.assertEqual(model.dtype, torch.float16) + @require_torch @is_staging_test