add _keep_in_fp32_modules_strict (#39058)

* add _keep_in_fp32_modules_strict

* complete test
This commit is contained in:
eustlb
2025-06-26 15:55:28 +02:00
committed by GitHub
parent d973e62fdd
commit 02ecdcfc0f
4 changed files with 111 additions and 17 deletions

View File

@@ -1937,7 +1937,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
_auto_class = None
_no_split_modules = None
_skip_keys_device_placement = None
_keep_in_fp32_modules = None
# modules listed in `_keep_in_fp32_modules` are never cast to a dtype other than float32, with the exception of bfloat16;
# to also prevent bfloat16 casting, use the `_keep_in_fp32_modules_strict` flag
_keep_in_fp32_modules_strict = None
# a list of `re` patterns matching `state_dict` keys that should be removed from the list of missing
# keys we find (keys inside the model but not in the checkpoint), in order to avoid unnecessary warnings.
@@ -2049,6 +2053,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
# `InstructBlipForConditionalGeneration` can dynamically update it without modifying the class attribute
# when a different component (e.g. language_model) is used.
self._keep_in_fp32_modules = copy.copy(self.__class__._keep_in_fp32_modules)
self._keep_in_fp32_modules_strict = copy.copy(self.__class__._keep_in_fp32_modules_strict)
self._no_split_modules = self._no_split_modules or []
@@ -2061,7 +2066,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
self._backward_compatibility_gradient_checkpointing()
# Make sure the modules correctly exist if the flag is active
if self._keep_in_fp32_modules is not None:
if self._keep_in_fp32_modules is not None or self._keep_in_fp32_modules_strict is not None:
all_parameters = {name for name, _ in self.named_parameters() if len(name) > 0}
unique_module_names = set()
# Get all unique module names in the module graph, without the prefixes
@@ -2070,12 +2075,21 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
[name for name in param.split(".") if not name.isnumeric() and name not in ["weight", "bias"]]
)
# Check that every module in the keep_in_fp32 list is part of the module graph
for module in self._keep_in_fp32_modules:
if module not in unique_module_names:
raise ValueError(
f"{module} was specified in the `_keep_in_fp32_modules` list, but is not part of the modules in"
f" {self.__class__.__name__}"
)
if self._keep_in_fp32_modules is not None:
for module in self._keep_in_fp32_modules:
if module not in unique_module_names:
raise ValueError(
f"{module} was specified in the `_keep_in_fp32_modules` list, but is not part of the modules in"
f" {self.__class__.__name__}"
)
if self._keep_in_fp32_modules_strict is not None:
for module in self._keep_in_fp32_modules_strict:
if module not in unique_module_names:
raise ValueError(
f"{module} was specified in the `_keep_in_fp32_modules_strict` list, but is not part of the modules in"
f" {self.__class__.__name__}"
)
# If current model is a base model, attach `base_model_tp_plan` and `base_model_pp_plan` from config
self._pp_plan = self.config.base_model_pp_plan.copy() if self.config.base_model_pp_plan is not None else None
@@ -4757,20 +4771,24 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
config = model.config
# Find fp32 modules if needed
keep_in_fp32_regex = None
keep_in_fp32_modules = []
# The `_keep_in_fp32_modules` flag is only used to avoid precision issues when casting bf16 -> fp16. It was introduced
# for the case where a model that should stay in bf16 is force-loaded in fp16 (which includes a few quantizers as this is a pre-processing
# step for e.g. bitsandbytes). See https://github.com/huggingface/transformers/issues/20287 for details.
# Update: the `_keep_in_fp32_modules` feature has been extended, so it can also be used to force modules to stay in fp32.
if model._keep_in_fp32_modules is not None and (
torch_dtype == torch.float16
or torch_dtype == torch.bfloat16
or getattr(hf_quantizer, "use_keep_in_fp32_modules", False)
torch_dtype == torch.float16 or getattr(hf_quantizer, "use_keep_in_fp32_modules", False)
):
keep_in_fp32_modules.extend(model._keep_in_fp32_modules)
if model._keep_in_fp32_modules_strict is not None and (
torch_dtype == torch.float16 or torch_dtype == torch.bfloat16
):
keep_in_fp32_modules.extend(model._keep_in_fp32_modules_strict)
keep_in_fp32_regex = None
if keep_in_fp32_modules:
# We need to match exact layers, so we add either `.` on each side, or start/end of string
keep_in_fp32_regex = re.compile(
"|".join([rf"((^|\.){module}($|\.))" for module in model._keep_in_fp32_modules])
)
keep_in_fp32_regex = re.compile("|".join([rf"((^|\.){module}($|\.))" for module in keep_in_fp32_modules]))
if hf_quantizer is not None:
hf_quantizer.preprocess_model(