From da971b2271728addaac5b8fd3fb68a2894e71995 Mon Sep 17 00:00:00 2001 From: fxmarty <9808326+fxmarty@users.noreply.github.com> Date: Thu, 21 Sep 2023 19:00:03 +0900 Subject: [PATCH] Keep relevant weights in fp32 when `model._keep_in_fp32_modules` is set even when `accelerate` is not installed (#26225) * fix bug where weight would not be kept in fp32 * nit * address review comments * fix test --- src/transformers/modeling_utils.py | 27 ++++++++++--------------- tests/models/t5/test_modeling_t5.py | 31 +++++++++++++++++++++-------- 2 files changed, 33 insertions(+), 25 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 1432c3b78a..1fd1570e8c 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -2950,26 +2950,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix dtype_orig = cls._set_default_torch_dtype(torch_dtype) # Check if `_keep_in_fp32_modules` is not None - use_keep_in_fp32_modules = ( - (cls._keep_in_fp32_modules is not None) - and is_accelerate_available() - and (torch_dtype == torch.float16 or load_in_4bit or load_in_8bit) + use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and ( + torch_dtype == torch.float16 or load_in_4bit or load_in_8bit ) - if ( - (cls._keep_in_fp32_modules is not None) - and not is_accelerate_available() - and torch_dtype == torch.float16 - ): - logger.warning( - "For stability purposes, it is recommended to have accelerate installed when using this model in" - " torch.float16, please install it with `pip install accelerate`" - ) if is_sharded: loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"] else: loaded_state_dict_keys = list(state_dict.keys()) - if low_cpu_mem_usage or use_keep_in_fp32_modules: + if low_cpu_mem_usage or (use_keep_in_fp32_modules and is_accelerate_available()): + # In case some weights need to be kept in float32 and accelerate is not installed, + # we later on want to take the path where state_dict is not None, that is the one + # that do not require accelerate. state_dict = None config.name_or_path = pretrained_model_name_or_path @@ -2990,7 +2982,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix # Check first if we are `from_pt` if use_keep_in_fp32_modules: - low_cpu_mem_usage = True + if is_accelerate_available(): + low_cpu_mem_usage = True keep_in_fp32_modules = model._keep_in_fp32_modules else: keep_in_fp32_modules = [] @@ -3465,7 +3458,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix if keep_in_fp32_modules is not None: for name, param in model.named_parameters(): if any(module_to_keep_in_fp32 in name for module_to_keep_in_fp32 in keep_in_fp32_modules): - param = param.to(torch.float32) + # param = param.to(torch.float32) does not work here as only in the local scope. + param.data = param.data.to(torch.float32) # Make sure we are able to load base models as well as derived models (with heads) start_prefix = "" @@ -3592,7 +3586,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix remove_prefix_from_model, ignore_mismatched_sizes, ) - if low_cpu_mem_usage: if not is_fsdp_enabled() or is_fsdp_enabled_and_dist_rank_0(): new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model( diff --git a/tests/models/t5/test_modeling_t5.py b/tests/models/t5/test_modeling_t5.py index cae891ef8b..c94bfc1f11 100644 --- a/tests/models/t5/test_modeling_t5.py +++ b/tests/models/t5/test_modeling_t5.py @@ -1046,15 +1046,30 @@ class T5ModelFp16Tests(unittest.TestCase): r""" A test to check whether the argument `keep_in_fp32_modules` correctly does its job """ - # Load without using `accelerate` - model = T5ForConditionalGeneration.from_pretrained("t5-small", torch_dtype=torch.float16) - self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.float32) - self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.float16) + orig_import = __import__ + accelerate_mock = unittest.mock.Mock() - # Load without in bf16 - model = T5ForConditionalGeneration.from_pretrained("t5-small", torch_dtype=torch.bfloat16) - self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.bfloat16) - self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.bfloat16) + # mock import of accelerate + def import_accelerate_mock(name, *args, **kwargs): + if name == "accelerate": + if accelerate_available: + return accelerate_mock + else: + raise ImportError + return orig_import(name, *args, **kwargs) + + # Load without using `accelerate` + with unittest.mock.patch("builtins.__import__", side_effect=import_accelerate_mock): + accelerate_available = False + + model = T5ForConditionalGeneration.from_pretrained("t5-small", torch_dtype=torch.float16) + self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.float32) + self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.float16) + + # Load without in bf16 + model = T5ForConditionalGeneration.from_pretrained("t5-small", torch_dtype=torch.bfloat16) + self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.bfloat16) + self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.bfloat16) # Load using `accelerate` in bf16 model = T5ForConditionalGeneration.from_pretrained("t5-small", torch_dtype=torch.bfloat16, device_map="auto")