From 02ecdcfc0f7d81e90a9c8e7f9e6d636123a84254 Mon Sep 17 00:00:00 2001 From: eustlb <94853470+eustlb@users.noreply.github.com> Date: Thu, 26 Jun 2025 15:55:28 +0200 Subject: [PATCH] add _keep_in_fp32_modules_strict (#39058) * add _keep_in_fp32_modules_strict * complete test --- src/transformers/modeling_utils.py | 48 ++++++++---- .../modeling_kyutai_speech_to_text.py | 2 +- .../modular_kyutai_speech_to_text.py | 2 +- .../test_modeling_kyutai_speech_to_text.py | 76 +++++++++++++++++++ 4 files changed, 111 insertions(+), 17 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index ea2bd32aa3..515fb6d381 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1937,7 +1937,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi _auto_class = None _no_split_modules = None _skip_keys_device_placement = None + _keep_in_fp32_modules = None + # the _keep_in_fp32_modules will avoid casting to anything other than float32, except bfloat16 + # to also prevent bfloat16 casting, use the _keep_in_fp32_modules_strict flag + _keep_in_fp32_modules_strict = None # a list of `re` patterns of `state_dict` keys that should be removed from the list of missing # keys we find (keys inside the model but not in the checkpoint) and avoid unnecessary warnings. @@ -2049,6 +2053,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi # `InstructBlipForConditionalGeneration` can dynamically update it without modifying the class attribute # when a different component (e.g. language_model) is used. 
self._keep_in_fp32_modules = copy.copy(self.__class__._keep_in_fp32_modules) + self._keep_in_fp32_modules_strict = copy.copy(self.__class__._keep_in_fp32_modules_strict) self._no_split_modules = self._no_split_modules or [] @@ -2061,7 +2066,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi self._backward_compatibility_gradient_checkpointing() # Make sure the modules correctly exist if the flag is active - if self._keep_in_fp32_modules is not None: + if self._keep_in_fp32_modules is not None or self._keep_in_fp32_modules_strict is not None: all_parameters = {name for name, _ in self.named_parameters() if len(name) > 0} unique_module_names = set() # Get all unique module names in the module graph, without the prefixes @@ -2070,12 +2075,21 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi [name for name in param.split(".") if not name.isnumeric() and name not in ["weight", "bias"]] ) # Check that every module in the keep_in_fp32 list is part of the module graph - for module in self._keep_in_fp32_modules: - if module not in unique_module_names: - raise ValueError( - f"{module} was specified in the `_keep_in_fp32_modules` list, but is not part of the modules in" - f" {self.__class__.__name__}" - ) + if self._keep_in_fp32_modules is not None: + for module in self._keep_in_fp32_modules: + if module not in unique_module_names: + raise ValueError( + f"{module} was specified in the `_keep_in_fp32_modules` list, but is not part of the modules in" + f" {self.__class__.__name__}" + ) + + if self._keep_in_fp32_modules_strict is not None: + for module in self._keep_in_fp32_modules_strict: + if module not in unique_module_names: + raise ValueError( + f"{module} was specified in the `_keep_in_fp32_modules_strict` list, but is not part of the modules in" + f" {self.__class__.__name__}" + ) # If current model is a base model, attach `base_model_tp_plan` and `base_model_pp_plan` from config self._pp_plan = 
self.config.base_model_pp_plan.copy() if self.config.base_model_pp_plan is not None else None @@ -4757,20 +4771,24 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi config = model.config # Find fp32 modules if needed - keep_in_fp32_regex = None + keep_in_fp32_modules = [] # The _keep_in_fp32_modules flag is only used to avoid bf16 -> fp16 casting precision issues. It was introduced # in case of force loading a model that should stay bf16 in fp16 (which includes a few quantizers as this is a pre-processing # step for e.g. bitsandbytes). See https://github.com/huggingface/transformers/issues/20287 for details. - # Update: to extend _keep_in_fp32_modules flag feature, it can also be used to force modules that should stay in fp32 if model._keep_in_fp32_modules is not None and ( - torch_dtype == torch.float16 - or torch_dtype == torch.bfloat16 - or getattr(hf_quantizer, "use_keep_in_fp32_modules", False) + torch_dtype == torch.float16 or getattr(hf_quantizer, "use_keep_in_fp32_modules", False) ): + keep_in_fp32_modules.extend(model._keep_in_fp32_modules) + + if model._keep_in_fp32_modules_strict is not None and ( + torch_dtype == torch.float16 or torch_dtype == torch.bfloat16 + ): + keep_in_fp32_modules.extend(model._keep_in_fp32_modules_strict) + + keep_in_fp32_regex = None + if keep_in_fp32_modules: # We need to match exact layers, so we add either `.` on each side, or start/end of string - keep_in_fp32_regex = re.compile( - "|".join([rf"((^|\.){module}($|\.))" for module in model._keep_in_fp32_modules]) - ) + keep_in_fp32_regex = re.compile("|".join([rf"((^|\.){module}($|\.))" for module in keep_in_fp32_modules])) if hf_quantizer is not None: hf_quantizer.preprocess_model( diff --git a/src/transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py b/src/transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py index 67c4dac4cc..5abc0bd3fc 100644 --- 
a/src/transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +++ b/src/transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py @@ -1103,7 +1103,7 @@ class KyutaiSpeechToTextForConditionalGeneration(KyutaiSpeechToTextPreTrainedMod _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} - _keep_in_fp32_modules = ["codec_model"] + _keep_in_fp32_modules_strict = ["codec_model"] def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py b/src/transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py index a9b86c6e2c..4929c9e4ba 100644 --- a/src/transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +++ b/src/transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py @@ -252,7 +252,7 @@ class KyutaiSpeechToTextModel(MoshiModel): class KyutaiSpeechToTextForConditionalGeneration(LlamaForCausalLM, GenerationMixin, PreTrainedModel): - _keep_in_fp32_modules = ["codec_model"] + _keep_in_fp32_modules_strict = ["codec_model"] def __init__(self, config): super().__init__(config) diff --git a/tests/models/kyutai_speech_to_text/test_modeling_kyutai_speech_to_text.py b/tests/models/kyutai_speech_to_text/test_modeling_kyutai_speech_to_text.py index 822bc872bc..780658c77a 100644 --- a/tests/models/kyutai_speech_to_text/test_modeling_kyutai_speech_to_text.py +++ b/tests/models/kyutai_speech_to_text/test_modeling_kyutai_speech_to_text.py @@ -30,6 +30,7 @@ from transformers import ( ) from transformers.testing_utils import ( cleanup, + require_accelerate, require_torch, require_torch_accelerator, require_torch_sdpa, @@ -615,6 +616,81 @@ class KyutaiSpeechToTextModelTest(ModelTesterMixin, GenerationTesterMixin, Pipel self._check_similar_generate_outputs(res_eager, res_attn, atol=1e-3, rtol=1e-3) +@require_torch +@require_accelerate 
+@slow +class KyutaiSpeechToTextBf16Test(unittest.TestCase): + def test_bf16_fp32_conversion(self): + r""" + A test to check whether the `_keep_in_fp32_modules_strict` attribute correctly does its job + """ + model_checkpoint = "kyutai/stt-2.6b-en-trfs" + orig_import = __import__ + accelerate_mock = unittest.mock.Mock() + + # mock import of accelerate + def import_accelerate_mock(name, *args, **kwargs): + if name == "accelerate": + if accelerate_available: + return accelerate_mock + else: + raise ImportError + return orig_import(name, *args, **kwargs) + + # Load without using `accelerate` + with unittest.mock.patch("builtins.__import__", side_effect=import_accelerate_mock): + accelerate_available = False + + model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained( + model_checkpoint, torch_dtype=torch.float16 + ) + self.assertTrue(model.codec_model.dtype == torch.float32) + self.assertTrue(model.model.dtype == torch.float16) + self.assertTrue(model.lm_head.weight.data.dtype == torch.float16) + + # Load without using `accelerate`, in bf16 + model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained( + model_checkpoint, torch_dtype=torch.bfloat16 + ) + self.assertTrue(model.codec_model.dtype == torch.float32) + self.assertTrue(model.model.dtype == torch.bfloat16) + self.assertTrue(model.lm_head.weight.data.dtype == torch.bfloat16) + + # Load using `accelerate` in bf16 + model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained( + model_checkpoint, torch_dtype=torch.bfloat16, device_map="auto" + ) + self.assertTrue(model.codec_model.dtype == torch.float32) + self.assertTrue(model.model.dtype == torch.bfloat16) + self.assertTrue(model.lm_head.weight.data.dtype == torch.bfloat16) + + # Load in bf16 without an explicit `device_map` + model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained( + model_checkpoint, + torch_dtype=torch.bfloat16, + ) + self.assertTrue(model.codec_model.dtype == torch.float32) + self.assertTrue(model.model.dtype == torch.bfloat16) + 
self.assertTrue(model.lm_head.weight.data.dtype == torch.bfloat16) + + # Load without using `accelerate` + model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained( + model_checkpoint, + torch_dtype=torch.float16, + ) + self.assertTrue(model.codec_model.dtype == torch.float32) + self.assertTrue(model.model.dtype == torch.float16) + self.assertTrue(model.lm_head.weight.data.dtype == torch.float16) + + # Load using `accelerate` + model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained( + model_checkpoint, torch_dtype=torch.float16, device_map="auto" + ) + self.assertTrue(model.codec_model.dtype == torch.float32) + self.assertTrue(model.model.dtype == torch.float16) + self.assertTrue(model.lm_head.weight.data.dtype == torch.float16) + + class KyutaiSpeechToTextForConditionalGenerationIntegrationTests(unittest.TestCase): _dataset = None