Simplify keep_in_fp32_modules logic (#36722)

* better regex everywhere

* fix

* Update test_modeling_instructblip.py

* BC with explanations this time otherwise it makes no sense at all

* Update test_modeling_instructblip.py

* style

* CIs

* update _keep_in_fp32_modules in blip2

* Update modeling_utils.py

* Update modeling_utils.py

* style

* CIs

* add check

* trigger CIs

* Update modeling_utils.py

* trigger CIs
This commit is contained in:
Cyril Vallez
2025-03-21 16:12:59 +01:00
committed by GitHub
parent 90e2df5d55
commit dd3933dd65
3 changed files with 47 additions and 31 deletions

View File

@@ -716,7 +716,7 @@ def _infer_parameter_dtype(
model: "PreTrainedModel",
param_name: str,
empty_param: torch.Tensor,
keep_in_fp32_modules: Optional[List[str]] = None,
keep_in_fp32_regex: Optional[re.Pattern] = None,
hf_quantizer: Optional[HfQuantizer] = None,
) -> Union[bool, Optional[torch.dtype]]:
try:
@@ -733,7 +733,7 @@ def _infer_parameter_dtype(
is_param_float8_e4m3fn = is_torch_e4m3fn_available and empty_param.dtype == torch.float8_e4m3fn
if empty_param.dtype.is_floating_point and not is_param_float8_e4m3fn:
# First fp32 if part of the exception list
if keep_in_fp32_modules is not None and keep_in_fp32_modules.search(param_name):
if keep_in_fp32_regex is not None and keep_in_fp32_regex.search(param_name):
casting_dtype = torch.float32
# Then dtype that was instantiated in the meta model -- note that this respects subconfigs dtypes
elif hf_quantizer is not None:
@@ -757,7 +757,7 @@ def _load_state_dict_into_meta_model(
cpu_offload_index: Optional[Dict] = None,
hf_quantizer: Optional[HfQuantizer] = None,
is_safetensors: bool = False,
keep_in_fp32_modules: Optional[List[str]] = None,
keep_in_fp32_regex: Optional[re.Pattern] = None,
unexpected_keys: Optional[List[str]] = None, # passing `unexpected` for cleanup from quantization items
device_mesh: Optional["torch.distributed.device_mesh.DeviceMesh"] = None,
) -> Tuple[Optional[Dict], Optional[Dict]]:
@@ -795,7 +795,7 @@ def _load_state_dict_into_meta_model(
model,
param_name,
empty_param,
keep_in_fp32_modules,
keep_in_fp32_regex,
hf_quantizer,
)
@@ -1284,7 +1284,7 @@ def _get_device_map(
max_memory: Optional[Dict],
hf_quantizer: Optional[HfQuantizer],
torch_dtype: Optional[torch.dtype],
keep_in_fp32_modules: Optional[List[str]],
keep_in_fp32_regex: Optional[re.Pattern],
) -> Dict:
"""Compute the final `device_map` to use if we passed a value in ['auto', 'balanced', 'balanced_low_0', 'sequential'].
Otherwise, we check for any device inconsistencies in the device_map.
@@ -1293,13 +1293,9 @@ def _get_device_map(
special_dtypes = {}
if hf_quantizer is not None:
special_dtypes.update(hf_quantizer.get_special_dtypes_update(model, torch_dtype))
if keep_in_fp32_modules is not None:
if keep_in_fp32_regex is not None:
special_dtypes.update(
{
name: torch.float32
for name, _ in model.named_parameters()
if any(m in name for m in keep_in_fp32_modules)
}
{name: torch.float32 for name, _ in model.named_parameters() if keep_in_fp32_regex.search(name)}
)
target_dtype = torch_dtype
@@ -1911,6 +1907,23 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
self.init_weights()
self._backward_compatibility_gradient_checkpointing()
# Make sure the modules correctly exist if the flag is active
if self._keep_in_fp32_modules is not None:
all_parameters = {name for name, _ in self.named_parameters() if len(name) > 0}
unique_module_names = set()
# Get all unique module names in the module graph, without the prefixes
for param in all_parameters:
unique_module_names.update(
[name for name in param.split(".") if not name.isnumeric() and name not in ["weight", "bias"]]
)
# Check that every module in the keep_in_fp32 list is part of the module graph
for module in self._keep_in_fp32_modules:
if module not in unique_module_names:
raise ValueError(
f"{module} was specified in the `_keep_in_fp32_modules` list, but is not part of the modules in"
f" {self.__class__.__name__}"
)
# If current model is a base model, attach `base_model_tp_plan` and `base_model_pp_plan` from config
if self.base_model is self:
self._pp_plan = (
@@ -4412,15 +4425,23 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
config = model.config
# Find fp32 modules if needed
keep_in_fp32_modules = None
if model._keep_in_fp32_modules is not None:
if is_accelerate_available() and not is_deepspeed_zero3_enabled():
low_cpu_mem_usage = True
keep_in_fp32_modules = model._keep_in_fp32_modules if len(model._keep_in_fp32_modules) > 0 else None
keep_in_fp32_regex = None
# The _keep_in_fp32_modules flag is only used to avoid bf16 -> fp16 casting precision issues. It was introduced
# in case of force loading a model that should stay bf16 in fp16 (which includes a few quantizers as this is a pre-processing
# step for e.g. bitsandbytes). See https://github.com/huggingface/transformers/issues/20287 for details.
if model._keep_in_fp32_modules is not None and (
torch_dtype == torch.float16 or getattr(hf_quantizer, "use_keep_in_fp32_modules", False)
):
# Only the path with `low_cpu_mem_usage` will check every param for the correct dtype
low_cpu_mem_usage = True
# We need to match exact layers, so we add either `.` on each side, or start/end of string
keep_in_fp32_regex = re.compile(
"|".join([rf"((^|\.){module}($|\.))" for module in model._keep_in_fp32_modules])
)
if hf_quantizer is not None:
hf_quantizer.preprocess_model(
model=model, device_map=device_map, keep_in_fp32_modules=keep_in_fp32_modules
model=model, device_map=device_map, keep_in_fp32_modules=model._keep_in_fp32_modules
)
# We store the original dtype for quantized models as we cannot easily retrieve it
@@ -4431,9 +4452,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# Prepare the full device map
if device_map is not None:
device_map = _get_device_map(
model, device_map, max_memory, hf_quantizer, torch_dtype, keep_in_fp32_modules
)
device_map = _get_device_map(model, device_map, max_memory, hf_quantizer, torch_dtype, keep_in_fp32_regex)
# Finalize model weight initialization
if from_tf:
@@ -4465,7 +4484,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
offload_state_dict=offload_state_dict,
dtype=torch_dtype,
hf_quantizer=hf_quantizer,
keep_in_fp32_modules=keep_in_fp32_modules,
keep_in_fp32_regex=keep_in_fp32_regex,
device_mesh=device_mesh,
key_mapping=key_mapping,
weights_only=weights_only,
@@ -4674,7 +4693,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
offload_state_dict: Optional[bool] = None,
dtype: Optional[torch.dtype] = None,
hf_quantizer: Optional[HfQuantizer] = None,
keep_in_fp32_modules: Optional[List[str]] = None,
keep_in_fp32_regex: Optional[re.Pattern] = None,
device_mesh: Optional["torch.distributed.device_mesh.DeviceMesh"] = None,
key_mapping: Optional[Dict[str, str]] = None,
weights_only: bool = True,
@@ -4736,10 +4755,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
model._initialize_missing_keys(checkpoint_keys, ignore_mismatched_sizes, is_quantized)
# Set some modules to fp32 if needed
if keep_in_fp32_modules is not None:
keep_in_fp32_modules = re.compile("|".join([re.escape(module) for module in keep_in_fp32_modules]))
if keep_in_fp32_regex is not None:
for name, param in model.named_parameters():
if keep_in_fp32_modules.search(name):
if keep_in_fp32_regex.search(name):
# param = param.to(torch.float32) does not work here as only in the local scope.
param.data = param.data.to(torch.float32)
@@ -4894,7 +4912,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
cpu_offload_index=cpu_offload_index,
hf_quantizer=hf_quantizer,
is_safetensors=is_offloaded_safetensors,
keep_in_fp32_modules=keep_in_fp32_modules,
keep_in_fp32_regex=keep_in_fp32_regex,
unexpected_keys=unexpected_keys,
device_mesh=device_mesh,
)
@@ -4951,7 +4969,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
}
for name, param in parameters_to_initialize.items():
# First move data to correct
to_contiguous, casting_dtype = _infer_parameter_dtype(model, name, param, keep_in_fp32_modules)
to_contiguous, casting_dtype = _infer_parameter_dtype(model, name, param, keep_in_fp32_regex)
shard_and_distribute_module(
model,
param.to(tp_device),