From 2f5507580be2f963afc6cd9c1d2340c81d90a2e9 Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Fri, 10 Feb 2023 09:09:21 -0800 Subject: [PATCH] [from_pretrained] extend `torch_dtype="auto"` to look up `config.torch_dtype` first, expand docs (#21524) * [from_pretrained] expand on torch_dtype entry * fold 4 into 1 * style * support torch_dtype='config' plus tests * style * oops * fold config into auto, fix bug * fix check * better log * better log * clean up --- src/transformers/modeling_utils.py | 57 +++++++++++++++----- src/transformers/models/auto/auto_factory.py | 9 +++- tests/test_modeling_common.py | 32 ++++++++--- 3 files changed, 78 insertions(+), 20 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index f331eecddf..15be6bca20 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1904,7 +1904,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git. - To test a pull request you made on the Hub, you can pass `revision="refs/pr/". @@ -1932,8 +1931,27 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix Tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model. This is an experimental feature and a subject to change at any moment. torch_dtype (`str` or `torch.dtype`, *optional*): - Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype - will be automatically derived from the model's weights. + Override the default `torch.dtype` and load the model under a specific `dtype`. The different options + are: + + 1. `torch.float16` or `torch.bfloat16` or `torch.float`: load in a specified + `dtype`, ignoring the model's `config.torch_dtype` if one exists. 
If not specified + - the model will get loaded in `torch.float` (fp32). + + 2. `"auto"` - The `torch_dtype` entry in the model's `config.json` file is used if it exists. + If this entry isn't found, the `dtype` of the first weight in + the checkpoint that's of a floating point type is used instead. This will load the model + using the `dtype` it was saved in at the end of the training. It can't be used as an indicator of how + the model was trained, since it could have been trained in one of the half precision dtypes, but saved in fp32. + + + + For some models the `dtype` they were trained in is unknown - you may try to check the model's paper or + reach out to the authors and ask them to add this information to the model's card and to insert the + `torch_dtype` entry in `config.json` on the hub. + + + device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*): A map that specifies where each submodule should go. It doesn't need to be refined to each parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the @@ -2098,10 +2116,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix " bitsandbytes `pip install -i https://test.pypi.org/simple/ bitsandbytes` or" " pip install bitsandbytes` " ) - if torch_dtype == "auto" or torch_dtype != torch.float16: + if torch_dtype != torch.float16: # We force the `dtype` to be float16, this is a requirement from `bitsandbytes` + logger.warning( + f"Overriding torch_dtype={torch_dtype} with `torch_dtype=torch.float16` due to " + "requirements of `bitsandbytes` to enable model loading in mixed int8. " + "Either pass torch_dtype=torch.float16 or don't pass this argument at all to remove this warning."
+ ) torch_dtype = torch.float16 - logger.info("Loading the model in mixed int8 - forcing the weights to be casted in float16") + if device_map is None: raise ValueError( "A device map needs to be passed to run convert models into mixed-int8 format. Please run" @@ -2388,17 +2411,25 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix if torch_dtype is not None: if isinstance(torch_dtype, str): if torch_dtype == "auto": - if is_sharded and "dtype" in sharded_metadata: - torch_dtype = sharded_metadata["dtype"] - elif not is_sharded: - torch_dtype = get_state_dict_dtype(state_dict) + if hasattr(config, "torch_dtype") and config.torch_dtype is not None: + torch_dtype = config.torch_dtype + logger.info(f"Will use torch_dtype={torch_dtype} as defined in model's config object") else: - one_state_dict = load_state_dict(resolved_archive_file[0]) - torch_dtype = get_state_dict_dtype(one_state_dict) - del one_state_dict # free CPU memory + if is_sharded and "dtype" in sharded_metadata: + torch_dtype = sharded_metadata["dtype"] + elif not is_sharded: + torch_dtype = get_state_dict_dtype(state_dict) + else: + one_state_dict = load_state_dict(resolved_archive_file[0]) + torch_dtype = get_state_dict_dtype(one_state_dict) + del one_state_dict # free CPU memory + logger.info( + "Since the `torch_dtype` attribute can't be found in model's config object, " + f"will use torch_dtype={torch_dtype} as derived from model's weights" + ) else: raise ValueError( - f"`torch_dtype` can be either a `torch.dtype` or `auto`, but received {torch_dtype}" + f'`torch_dtype` can be either `torch.dtype` or `"auto"`, but received {torch_dtype}' ) dtype_orig = cls._set_default_torch_dtype(torch_dtype) diff --git a/src/transformers/models/auto/auto_factory.py b/src/transformers/models/auto/auto_factory.py index c307bf5f99..eb87bb1ff7 100644 --- a/src/transformers/models/auto/auto_factory.py +++ b/src/transformers/models/auto/auto_factory.py @@ -13,6 +13,7 @@ # See the License
for the specific language governing permissions and # limitations under the License. """Factory function to build auto-model classes.""" +import copy import importlib from collections import OrderedDict @@ -431,12 +432,18 @@ class _BaseAutoModelClass: ] hub_kwargs = {name: kwargs.pop(name) for name in hub_kwargs_names if name in kwargs} if not isinstance(config, PretrainedConfig): + kwargs_copy = copy.deepcopy(kwargs) + # ensure not to pollute the config object with torch_dtype="auto" - since it's + # meaningless in the context of the config object - torch.dtype values are acceptable + if kwargs_copy.get("torch_dtype", None) == "auto": + _ = kwargs_copy.pop("torch_dtype") + config, kwargs = AutoConfig.from_pretrained( pretrained_model_name_or_path, return_unused_kwargs=True, trust_remote_code=trust_remote_code, **hub_kwargs, - **kwargs, + **kwargs_copy, ) if hasattr(config, "auto_map") and cls.__name__ in config.auto_map: if not trust_remote_code: diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 41217e266e..7cb5c4478c 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -2785,7 +2785,6 @@ class ModelUtilsTest(TestCasePlus): for p1, p2 in zip(model.parameters(), new_model.parameters()): self.assertTrue(torch.equal(p1, p2)) - @require_torch def test_model_from_config_torch_dtype(self): # test that the model can be instantiated with dtype of user's choice - as long as it's a # float dtype. To make it happen config.torch_dtype needs to be set before instantiating the @@ -2804,7 +2803,6 @@ class ModelUtilsTest(TestCasePlus): with self.assertRaises(ValueError): model = AutoModel.from_config(config, torch_dtype=torch.int64) - @require_torch def test_model_from_pretrained_torch_dtype(self): # test that the model can be instantiated with dtype of either # 1. 
explicit from_pretrained's torch_dtype argument @@ -2818,11 +2816,25 @@ class ModelUtilsTest(TestCasePlus): model = T5ForConditionalGeneration.from_pretrained(TINY_T5) self.assertEqual(model.dtype, torch.float32) + def remove_torch_dtype(model_path): + file = f"{model_path}/config.json" + with open(file, "r", encoding="utf-8") as f: + s = json.load(f) + s.pop("torch_dtype") + with open(file, "w", encoding="utf-8") as f: + json.dump(s, f) + # test the default fp32 save_pretrained => from_pretrained cycle model.save_pretrained(model_path) model = T5ForConditionalGeneration.from_pretrained(model_path) self.assertEqual(model.dtype, torch.float32) - # test with auto-detection + # 1. test torch_dtype="auto" via `config.torch_dtype` + model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype="auto") + self.assertEqual(model.dtype, torch.float32) + # 2. test torch_dtype="auto" via auto-derivation + # now remove the torch_dtype entry from config.json and try "auto" again which should + # perform auto-derivation from weights + remove_torch_dtype(model_path) model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype="auto") self.assertEqual(model.dtype, torch.float32) @@ -2833,24 +2845,32 @@ class ModelUtilsTest(TestCasePlus): # test fp16 save_pretrained, loaded with auto-detection model = model.half() model.save_pretrained(model_path) + # 1. test torch_dtype="auto" via `config.torch_dtype` model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype="auto") self.assertEqual(model.config.torch_dtype, torch.float16) self.assertEqual(model.dtype, torch.float16) - # tests `config.torch_dtype` saving with open(f"{model_path}/config.json") as f: config_dict = json.load(f) self.assertEqual(config_dict["torch_dtype"], "float16") + # 2. 
test torch_dtype="auto" via auto-derivation + # now same without using config info + remove_torch_dtype(model_path) + model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype="auto") + self.assertEqual(model.dtype, torch.float16) # test fp16 save_pretrained, loaded with the explicit fp16 model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype=torch.float16) self.assertEqual(model.dtype, torch.float16) # test AutoModel separately as it goes through a different path - # test auto-detection + # test auto-detection - as currently TINY_T5 doesn't have torch_dtype entry model = AutoModel.from_pretrained(TINY_T5, torch_dtype="auto") + # test that the config object didn't get polluted with torch_dtype="auto" + # there was a bug that after this call we ended up with config.torch_dtype=="auto" + self.assertNotEqual(model.config.torch_dtype, "auto") + # now test the outcome self.assertEqual(model.dtype, torch.float32) - # test forcing an explicit dtype model = AutoModel.from_pretrained(TINY_T5, torch_dtype=torch.float16) self.assertEqual(model.dtype, torch.float16)