Simplify keep_in_fp32_modules logic (#36722)
* better regex everywhere * fix * Update test_modeling_instructblip.py * BC with explanations this time otherwise it makes no sense at all * Update test_modeling_instructblip.py * style * CIs * update _keep_in_fp32_modules in blip2 * Update modeling_utils.py * Update modeling_utils.py * style * CIs * add check * trigger CIs * Update modeling_utils.py * trigger CIs
This commit is contained in:
@@ -716,7 +716,7 @@ def _infer_parameter_dtype(
|
|||||||
model: "PreTrainedModel",
|
model: "PreTrainedModel",
|
||||||
param_name: str,
|
param_name: str,
|
||||||
empty_param: torch.Tensor,
|
empty_param: torch.Tensor,
|
||||||
keep_in_fp32_modules: Optional[List[str]] = None,
|
keep_in_fp32_regex: Optional[re.Pattern] = None,
|
||||||
hf_quantizer: Optional[HfQuantizer] = None,
|
hf_quantizer: Optional[HfQuantizer] = None,
|
||||||
) -> Union[bool, Optional[torch.dtype]]:
|
) -> Union[bool, Optional[torch.dtype]]:
|
||||||
try:
|
try:
|
||||||
@@ -733,7 +733,7 @@ def _infer_parameter_dtype(
|
|||||||
is_param_float8_e4m3fn = is_torch_e4m3fn_available and empty_param.dtype == torch.float8_e4m3fn
|
is_param_float8_e4m3fn = is_torch_e4m3fn_available and empty_param.dtype == torch.float8_e4m3fn
|
||||||
if empty_param.dtype.is_floating_point and not is_param_float8_e4m3fn:
|
if empty_param.dtype.is_floating_point and not is_param_float8_e4m3fn:
|
||||||
# First fp32 if part of the exception list
|
# First fp32 if part of the exception list
|
||||||
if keep_in_fp32_modules is not None and keep_in_fp32_modules.search(param_name):
|
if keep_in_fp32_regex is not None and keep_in_fp32_regex.search(param_name):
|
||||||
casting_dtype = torch.float32
|
casting_dtype = torch.float32
|
||||||
# Then dtype that was instantiated in the meta model -- note that this respects subconfigs dtypes
|
# Then dtype that was instantiated in the meta model -- note that this respects subconfigs dtypes
|
||||||
elif hf_quantizer is not None:
|
elif hf_quantizer is not None:
|
||||||
@@ -757,7 +757,7 @@ def _load_state_dict_into_meta_model(
|
|||||||
cpu_offload_index: Optional[Dict] = None,
|
cpu_offload_index: Optional[Dict] = None,
|
||||||
hf_quantizer: Optional[HfQuantizer] = None,
|
hf_quantizer: Optional[HfQuantizer] = None,
|
||||||
is_safetensors: bool = False,
|
is_safetensors: bool = False,
|
||||||
keep_in_fp32_modules: Optional[List[str]] = None,
|
keep_in_fp32_regex: Optional[re.Pattern] = None,
|
||||||
unexpected_keys: Optional[List[str]] = None, # passing `unexpected` for cleanup from quantization items
|
unexpected_keys: Optional[List[str]] = None, # passing `unexpected` for cleanup from quantization items
|
||||||
device_mesh: Optional["torch.distributed.device_mesh.DeviceMesh"] = None,
|
device_mesh: Optional["torch.distributed.device_mesh.DeviceMesh"] = None,
|
||||||
) -> Tuple[Optional[Dict], Optional[Dict]]:
|
) -> Tuple[Optional[Dict], Optional[Dict]]:
|
||||||
@@ -795,7 +795,7 @@ def _load_state_dict_into_meta_model(
|
|||||||
model,
|
model,
|
||||||
param_name,
|
param_name,
|
||||||
empty_param,
|
empty_param,
|
||||||
keep_in_fp32_modules,
|
keep_in_fp32_regex,
|
||||||
hf_quantizer,
|
hf_quantizer,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1284,7 +1284,7 @@ def _get_device_map(
|
|||||||
max_memory: Optional[Dict],
|
max_memory: Optional[Dict],
|
||||||
hf_quantizer: Optional[HfQuantizer],
|
hf_quantizer: Optional[HfQuantizer],
|
||||||
torch_dtype: Optional[torch.dtype],
|
torch_dtype: Optional[torch.dtype],
|
||||||
keep_in_fp32_modules: Optional[List[str]],
|
keep_in_fp32_regex: Optional[re.Pattern],
|
||||||
) -> Dict:
|
) -> Dict:
|
||||||
"""Compute the final `device_map` to use if we passed a value in ['auto', 'balanced', 'balanced_low_0', 'sequential'].
|
"""Compute the final `device_map` to use if we passed a value in ['auto', 'balanced', 'balanced_low_0', 'sequential'].
|
||||||
Otherwise, we check for any device inconsistencies in the device_map.
|
Otherwise, we check for any device inconsistencies in the device_map.
|
||||||
@@ -1293,13 +1293,9 @@ def _get_device_map(
|
|||||||
special_dtypes = {}
|
special_dtypes = {}
|
||||||
if hf_quantizer is not None:
|
if hf_quantizer is not None:
|
||||||
special_dtypes.update(hf_quantizer.get_special_dtypes_update(model, torch_dtype))
|
special_dtypes.update(hf_quantizer.get_special_dtypes_update(model, torch_dtype))
|
||||||
if keep_in_fp32_modules is not None:
|
if keep_in_fp32_regex is not None:
|
||||||
special_dtypes.update(
|
special_dtypes.update(
|
||||||
{
|
{name: torch.float32 for name, _ in model.named_parameters() if keep_in_fp32_regex.search(name)}
|
||||||
name: torch.float32
|
|
||||||
for name, _ in model.named_parameters()
|
|
||||||
if any(m in name for m in keep_in_fp32_modules)
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
target_dtype = torch_dtype
|
target_dtype = torch_dtype
|
||||||
@@ -1911,6 +1907,23 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
self.init_weights()
|
self.init_weights()
|
||||||
self._backward_compatibility_gradient_checkpointing()
|
self._backward_compatibility_gradient_checkpointing()
|
||||||
|
|
||||||
|
# Make sure the modules correctly exist if the flag is active
|
||||||
|
if self._keep_in_fp32_modules is not None:
|
||||||
|
all_parameters = {name for name, _ in self.named_parameters() if len(name) > 0}
|
||||||
|
unique_module_names = set()
|
||||||
|
# Get all unique module names in the module graph, without the prefixes
|
||||||
|
for param in all_parameters:
|
||||||
|
unique_module_names.update(
|
||||||
|
[name for name in param.split(".") if not name.isnumeric() and name not in ["weight", "bias"]]
|
||||||
|
)
|
||||||
|
# Check that every module in the keep_in_fp32 list is part of the module graph
|
||||||
|
for module in self._keep_in_fp32_modules:
|
||||||
|
if module not in unique_module_names:
|
||||||
|
raise ValueError(
|
||||||
|
f"{module} was specified in the `_keep_in_fp32_modules` list, but is not part of the modules in"
|
||||||
|
f" {self.__class__.__name__}"
|
||||||
|
)
|
||||||
|
|
||||||
# If current model is a base model, attach `base_model_tp_plan` and `base_model_pp_plan` from config
|
# If current model is a base model, attach `base_model_tp_plan` and `base_model_pp_plan` from config
|
||||||
if self.base_model is self:
|
if self.base_model is self:
|
||||||
self._pp_plan = (
|
self._pp_plan = (
|
||||||
@@ -4412,15 +4425,23 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
config = model.config
|
config = model.config
|
||||||
|
|
||||||
# Find fp32 modules if needed
|
# Find fp32 modules if needed
|
||||||
keep_in_fp32_modules = None
|
keep_in_fp32_regex = None
|
||||||
if model._keep_in_fp32_modules is not None:
|
# The _keep_in_fp32_modules flag is only used to avoid bf16 -> fp16 casting precision issues. It was introduced
|
||||||
if is_accelerate_available() and not is_deepspeed_zero3_enabled():
|
# in case of force loading a model that should stay bf16 in fp16 (which includes a few quantizers as this is a pre-processing
|
||||||
low_cpu_mem_usage = True
|
# step for e.g. bitsandbytes). See https://github.com/huggingface/transformers/issues/20287 for details.
|
||||||
keep_in_fp32_modules = model._keep_in_fp32_modules if len(model._keep_in_fp32_modules) > 0 else None
|
if model._keep_in_fp32_modules is not None and (
|
||||||
|
torch_dtype == torch.float16 or getattr(hf_quantizer, "use_keep_in_fp32_modules", False)
|
||||||
|
):
|
||||||
|
# Only the path with `low_cpu_mem_usage` will check every param for the correct dtype
|
||||||
|
low_cpu_mem_usage = True
|
||||||
|
# We need to match exact layers, so we add either `.` on each side, or start/end of string
|
||||||
|
keep_in_fp32_regex = re.compile(
|
||||||
|
"|".join([rf"((^|\.){module}($|\.))" for module in model._keep_in_fp32_modules])
|
||||||
|
)
|
||||||
|
|
||||||
if hf_quantizer is not None:
|
if hf_quantizer is not None:
|
||||||
hf_quantizer.preprocess_model(
|
hf_quantizer.preprocess_model(
|
||||||
model=model, device_map=device_map, keep_in_fp32_modules=keep_in_fp32_modules
|
model=model, device_map=device_map, keep_in_fp32_modules=model._keep_in_fp32_modules
|
||||||
)
|
)
|
||||||
|
|
||||||
# We store the original dtype for quantized models as we cannot easily retrieve it
|
# We store the original dtype for quantized models as we cannot easily retrieve it
|
||||||
@@ -4431,9 +4452,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
|
|
||||||
# Prepare the full device map
|
# Prepare the full device map
|
||||||
if device_map is not None:
|
if device_map is not None:
|
||||||
device_map = _get_device_map(
|
device_map = _get_device_map(model, device_map, max_memory, hf_quantizer, torch_dtype, keep_in_fp32_regex)
|
||||||
model, device_map, max_memory, hf_quantizer, torch_dtype, keep_in_fp32_modules
|
|
||||||
)
|
|
||||||
|
|
||||||
# Finalize model weight initialization
|
# Finalize model weight initialization
|
||||||
if from_tf:
|
if from_tf:
|
||||||
@@ -4465,7 +4484,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
offload_state_dict=offload_state_dict,
|
offload_state_dict=offload_state_dict,
|
||||||
dtype=torch_dtype,
|
dtype=torch_dtype,
|
||||||
hf_quantizer=hf_quantizer,
|
hf_quantizer=hf_quantizer,
|
||||||
keep_in_fp32_modules=keep_in_fp32_modules,
|
keep_in_fp32_regex=keep_in_fp32_regex,
|
||||||
device_mesh=device_mesh,
|
device_mesh=device_mesh,
|
||||||
key_mapping=key_mapping,
|
key_mapping=key_mapping,
|
||||||
weights_only=weights_only,
|
weights_only=weights_only,
|
||||||
@@ -4674,7 +4693,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
offload_state_dict: Optional[bool] = None,
|
offload_state_dict: Optional[bool] = None,
|
||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
hf_quantizer: Optional[HfQuantizer] = None,
|
hf_quantizer: Optional[HfQuantizer] = None,
|
||||||
keep_in_fp32_modules: Optional[List[str]] = None,
|
keep_in_fp32_regex: Optional[re.Pattern] = None,
|
||||||
device_mesh: Optional["torch.distributed.device_mesh.DeviceMesh"] = None,
|
device_mesh: Optional["torch.distributed.device_mesh.DeviceMesh"] = None,
|
||||||
key_mapping: Optional[Dict[str, str]] = None,
|
key_mapping: Optional[Dict[str, str]] = None,
|
||||||
weights_only: bool = True,
|
weights_only: bool = True,
|
||||||
@@ -4736,10 +4755,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
model._initialize_missing_keys(checkpoint_keys, ignore_mismatched_sizes, is_quantized)
|
model._initialize_missing_keys(checkpoint_keys, ignore_mismatched_sizes, is_quantized)
|
||||||
|
|
||||||
# Set some modules to fp32 if needed
|
# Set some modules to fp32 if needed
|
||||||
if keep_in_fp32_modules is not None:
|
if keep_in_fp32_regex is not None:
|
||||||
keep_in_fp32_modules = re.compile("|".join([re.escape(module) for module in keep_in_fp32_modules]))
|
|
||||||
for name, param in model.named_parameters():
|
for name, param in model.named_parameters():
|
||||||
if keep_in_fp32_modules.search(name):
|
if keep_in_fp32_regex.search(name):
|
||||||
# param = param.to(torch.float32) does not work here as only in the local scope.
|
# param = param.to(torch.float32) does not work here as only in the local scope.
|
||||||
param.data = param.data.to(torch.float32)
|
param.data = param.data.to(torch.float32)
|
||||||
|
|
||||||
@@ -4894,7 +4912,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
cpu_offload_index=cpu_offload_index,
|
cpu_offload_index=cpu_offload_index,
|
||||||
hf_quantizer=hf_quantizer,
|
hf_quantizer=hf_quantizer,
|
||||||
is_safetensors=is_offloaded_safetensors,
|
is_safetensors=is_offloaded_safetensors,
|
||||||
keep_in_fp32_modules=keep_in_fp32_modules,
|
keep_in_fp32_regex=keep_in_fp32_regex,
|
||||||
unexpected_keys=unexpected_keys,
|
unexpected_keys=unexpected_keys,
|
||||||
device_mesh=device_mesh,
|
device_mesh=device_mesh,
|
||||||
)
|
)
|
||||||
@@ -4951,7 +4969,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
}
|
}
|
||||||
for name, param in parameters_to_initialize.items():
|
for name, param in parameters_to_initialize.items():
|
||||||
# First move data to correct
|
# First move data to correct
|
||||||
to_contiguous, casting_dtype = _infer_parameter_dtype(model, name, param, keep_in_fp32_modules)
|
to_contiguous, casting_dtype = _infer_parameter_dtype(model, name, param, keep_in_fp32_regex)
|
||||||
shard_and_distribute_module(
|
shard_and_distribute_module(
|
||||||
model,
|
model,
|
||||||
param.to(tp_device),
|
param.to(tp_device),
|
||||||
|
|||||||
@@ -419,7 +419,6 @@ class Blip2PreTrainedModel(PreTrainedModel):
|
|||||||
"OPTDecoderLayer",
|
"OPTDecoderLayer",
|
||||||
]
|
]
|
||||||
_skip_keys_device_placement = "past_key_values"
|
_skip_keys_device_placement = "past_key_values"
|
||||||
_keep_in_fp32_modules = ["query_tokens"]
|
|
||||||
|
|
||||||
def _init_weights(self, module):
|
def _init_weights(self, module):
|
||||||
"""Initialize the weights"""
|
"""Initialize the weights"""
|
||||||
@@ -1448,6 +1447,7 @@ class Blip2QFormerModel(Blip2PreTrainedModel):
|
|||||||
class Blip2Model(Blip2PreTrainedModel):
|
class Blip2Model(Blip2PreTrainedModel):
|
||||||
config_class = Blip2Config
|
config_class = Blip2Config
|
||||||
main_input_name = "pixel_values"
|
main_input_name = "pixel_values"
|
||||||
|
_keep_in_fp32_modules = ["query_tokens"]
|
||||||
|
|
||||||
def __init__(self, config: Blip2Config):
|
def __init__(self, config: Blip2Config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
@@ -2019,6 +2019,7 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
|
|||||||
_supports_cache_class = True
|
_supports_cache_class = True
|
||||||
_supports_static_cache = True
|
_supports_static_cache = True
|
||||||
_supports_quantized_cache = False # not all LM bacbones support (e.g. T5)
|
_supports_quantized_cache = False # not all LM bacbones support (e.g. T5)
|
||||||
|
_keep_in_fp32_modules = ["query_tokens"]
|
||||||
|
|
||||||
def __init__(self, config: Blip2Config):
|
def __init__(self, config: Blip2Config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|||||||
@@ -791,15 +791,12 @@ class InstructBlipModelIntegrationTest(unittest.TestCase):
|
|||||||
num_beams=5,
|
num_beams=5,
|
||||||
max_length=256,
|
max_length=256,
|
||||||
min_length=1,
|
min_length=1,
|
||||||
top_p=0.9,
|
|
||||||
repetition_penalty=1.5,
|
repetition_penalty=1.5,
|
||||||
length_penalty=1.0,
|
length_penalty=1.0,
|
||||||
temperature=1,
|
temperature=1,
|
||||||
)
|
)
|
||||||
generated_text = processor.batch_decode(outputs, skip_special_tokens=True)[0]
|
generated_text = processor.batch_decode(outputs, skip_special_tokens=True)[0]
|
||||||
|
|
||||||
expected_outputs = [0, 37, 1023, 9850, 7, 3, 9, 388, 3575, 53, 4954, 30, 8, 223, 13, 3, 9, 4459, 4049, 16, 8, 2214, 13, 3, 9, 3164, 690, 2815, 5, 37, 388, 19, 5119, 3, 9, 4459, 8677, 28, 3, 9, 2756, 4459, 6177, 6, 11, 3, 88, 19, 338, 46, 3575, 53, 1476, 12, 743, 112, 2491, 5, 37, 1023, 19, 7225, 788, 12, 8, 685, 24, 34, 1267, 3, 9, 388, 3575, 53, 4954, 30, 8, 223, 13, 3, 9, 4049, 16, 8, 2214, 13, 3, 9, 3164, 690, 2815, 5, 94, 19, 487, 24, 8, 388, 19, 1119, 12, 1097, 540, 57, 692, 112, 10428, 30, 8, 223, 13, 8, 4049, 6, 68, 34, 19, 92, 487, 24, 3, 88, 19, 1119, 12, 1097, 97, 57, 692, 112, 10428, 30, 8, 223, 13, 8, 4049, 16, 8, 2214, 13, 3, 9, 3164, 690, 2815, 5, 3, 13865, 13, 8, 1053, 21, 8, 388, 31, 7, 2874, 6, 34, 19, 964, 24, 3, 88, 19, 1119, 12, 1097, 97, 57, 692, 112, 10428, 30, 8, 223, 13, 8, 4049, 16, 8, 2214, 13, 3, 9, 3164, 690, 2815, 5, 1] # fmt: skip
|
|
||||||
|
|
||||||
expected_outputs = [0, 37, 7225, 1023, 9850, 7, 3, 9, 388, 3575, 53, 4954, 30, 8, 223, 13, 3, 9, 4459, 4049, 16, 8, 2214, 13, 3, 9, 3164, 690, 2815, 5, 37, 388, 19, 5119, 3, 9, 4459, 8677, 28, 46, 3575, 53, 1476, 5223, 12, 34, 6, 15495, 24, 3, 88, 19, 692, 112, 293, 10428, 44, 234, 1066, 145, 338, 3, 9, 50, 1106, 3522, 144, 42, 2192, 7919, 31, 7, 5, 37, 1023, 92, 1267, 3, 9, 381, 13, 119, 3203, 16, 8, 2458, 6, 379, 14264, 6, 9256, 7, 6, 11, 11718, 7, 5, 1] # fmt: skip
|
expected_outputs = [0, 37, 7225, 1023, 9850, 7, 3, 9, 388, 3575, 53, 4954, 30, 8, 223, 13, 3, 9, 4459, 4049, 16, 8, 2214, 13, 3, 9, 3164, 690, 2815, 5, 37, 388, 19, 5119, 3, 9, 4459, 8677, 28, 46, 3575, 53, 1476, 5223, 12, 34, 6, 15495, 24, 3, 88, 19, 692, 112, 293, 10428, 44, 234, 1066, 145, 338, 3, 9, 50, 1106, 3522, 144, 42, 2192, 7919, 31, 7, 5, 37, 1023, 92, 1267, 3, 9, 381, 13, 119, 3203, 16, 8, 2458, 6, 379, 14264, 6, 9256, 7, 6, 11, 11718, 7, 5, 1] # fmt: skip
|
||||||
|
|
||||||
self.assertEqual(outputs[0].tolist(), expected_outputs)
|
self.assertEqual(outputs[0].tolist(), expected_outputs)
|
||||||
|
|||||||
Reference in New Issue
Block a user