From 5c6eca71a983bae2589eed01e5c04fcf88ba5690 Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Tue, 24 Aug 2021 02:43:41 -0700 Subject: [PATCH] fix `AutoModel.from_pretrained(..., torch_dtype=...)` (#13209) * fix AutoModel.from_pretrained(..., torch_dtype=...) * fix to_diff_dict * add better test * torch is not always available when a model has self.torch_dtype --- src/transformers/configuration_utils.py | 28 ++++++++++++++++++++++++- tests/test_modeling_common.py | 22 ++++++++++++++++--- 2 files changed, 46 insertions(+), 4 deletions(-) diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index 8b6a0bc4ad..9649a176e8 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -30,6 +30,7 @@ from .file_utils import ( hf_bucket_url, is_offline_mode, is_remote_url, + is_torch_available, ) from .utils import logging @@ -207,6 +208,9 @@ class PretrainedConfig(PushToHubMixin): this attribute contains just the floating type string without the ``torch.`` prefix. For example, for ``torch.float16`` ``torch_dtype`` is the ``"float16"`` string. + This attribute is currently not being used during model loading time, but this may change in the future + versions. But we can already start preparing for the future by saving the dtype with save_pretrained. + TensorFlow specific parameters - **use_bfloat16** (:obj:`bool`, `optional`, defaults to :obj:`False`) -- Whether or not the model should use @@ -270,6 +274,14 @@ class PretrainedConfig(PushToHubMixin): else: self.num_labels = kwargs.pop("num_labels", 2) + if self.torch_dtype is not None and isinstance(self.torch_dtype, str): + # we will start using self.torch_dtype in v5, but to be consistent with + # from_pretrained's torch_dtype arg convert it to an actual torch.dtype object + if is_torch_available(): + import torch + + self.torch_dtype = getattr(torch, self.torch_dtype) + # Tokenizer arguments TODO: eventually tokenizer and models should share the same config self.tokenizer_class = kwargs.pop("tokenizer_class", None) self.prefix = kwargs.pop("prefix", None) @@ -574,7 +586,8 @@ class PretrainedConfig(PushToHubMixin): for key, value in kwargs.items(): if hasattr(config, key): setattr(config, key, value) - to_remove.append(key) + if key != "torch_dtype": + to_remove.append(key) for key in to_remove: kwargs.pop(key, None) @@ -640,6 +653,8 @@ class PretrainedConfig(PushToHubMixin): ): serializable_config_dict[key] = value + self.dict_torch_dtype_to_str(serializable_config_dict) + return serializable_config_dict def to_dict(self) -> Dict[str, Any]: @@ -656,6 +671,8 @@ class PretrainedConfig(PushToHubMixin): # Transformers version when serializing the model output["transformers_version"] = __version__ + self.dict_torch_dtype_to_str(output) + return output def to_json_string(self, use_diff: bool = True) -> str: @@ -738,6 +755,15 @@ class PretrainedConfig(PushToHubMixin): setattr(self, k, v) + def dict_torch_dtype_to_str(self, d: Dict[str, Any]) -> None: + """ + Checks whether the passed dictionary has a `torch_dtype` key and if it's not None, converts torch.dtype to a + string of just the type. For example, :obj:`torch.float32` get converted into `"float32"` string, which can + then be stored in the json format. + """ + if d.get("torch_dtype", None) is not None and not isinstance(d["torch_dtype"], str): + d["torch_dtype"] = str(d["torch_dtype"]).split(".")[1] + PretrainedConfig.push_to_hub = copy_func(PretrainedConfig.push_to_hub) PretrainedConfig.push_to_hub.__doc__ = PretrainedConfig.push_to_hub.__doc__.format( diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index d109bbf6a5..51246b628b 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -16,6 +16,7 @@ import copy import gc import inspect +import json import os.path import random import tempfile @@ -1663,9 +1664,11 @@ class ModelUtilsTest(TestCasePlus): @require_torch def test_model_from_pretrained_torch_dtype(self): # test that the model can be instantiated with dtype of either - # 1. config.torch_dtype setting in the saved model (priority) - # 2. via autodiscovery by looking at model weights + # 1. explicit from_pretrained's torch_dtype argument + # 2. via autodiscovery by looking at model weights (torch_dtype="auto") # so if a model.half() was saved, we want it to be instantiated as such. + # + # test an explicit model class, but also AutoModel separately as the latter goes through a different code path model_path = self.get_auto_remove_tmp_dir() # baseline - we know TINY_T5 is fp32 model @@ -1688,13 +1691,26 @@ class ModelUtilsTest(TestCasePlus): model = model.half() model.save_pretrained(model_path) model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype="auto") - self.assertEqual(model.config.torch_dtype, "float16") # tests `config.torch_dtype` saving + self.assertEqual(model.config.torch_dtype, torch.float16) self.assertEqual(model.dtype, torch.float16) + # tests `config.torch_dtype` saving + with open(f"{model_path}/config.json") as f: + config_dict = json.load(f) + self.assertEqual(config_dict["torch_dtype"], "float16") + # test fp16 save_pretrained, loaded with the explicit fp16 model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype=torch.float16) self.assertEqual(model.dtype, torch.float16) + # test AutoModel separately as it goes through a different path + # test auto-detection + model = AutoModel.from_pretrained(TINY_T5, torch_dtype="auto") + self.assertEqual(model.dtype, torch.float32) + # test forcing an explicit dtype + model = AutoModel.from_pretrained(TINY_T5, torch_dtype=torch.float16) + self.assertEqual(model.dtype, torch.float16) + @require_torch @is_staging_test