Refactor some core stuff (#36539)

* some config changes

* update

* current state

* update

* update

* updates and cleanup

* something that works

* fixup

* fixes

* nits

* nit

* nits and fix

* Update src/transformers/integrations/tensor_parallel.py

Co-authored-by: Lysandre Debut <hi@lysand.re>

* Update src/transformers/integrations/tensor_parallel.py

Co-authored-by: Lysandre Debut <hi@lysand.re>

* cleanup

* style

* safe import

* fix

* updates

* rename stuff an clean

* style

* small updates

* ups

* oups

* nit

* protect imports

* update tp

* rodfl

* arf

* turbo nit on init

* fix import error

* frumble gumbgle

* try to fix the import error

* should fix the non model test

* update keep in float32

* update

* fix

* nits

* fix subvconfigs

* test was weird

* nit

* fix failing test

* fix instruct blip

* fixes

* style

* x.com

* fix overwrite

* ok last bit of failing test

---------

Co-authored-by: Lysandre Debut <hi@lysand.re>
This commit is contained in:
Arthur
2025-03-11 09:26:28 +01:00
committed by GitHub
parent e9756cdbc7
commit 1c4b62b219
9 changed files with 704 additions and 116 deletions

View File

@@ -54,17 +54,20 @@ from .integrations.deepspeed import _load_state_dict_into_zero3_model
from .integrations.flash_attention import flash_attention_forward
from .integrations.flex_attention import flex_attention_forward
from .integrations.sdpa_attention import sdpa_attention_forward
from .integrations.tensor_parallel import (
SUPPORTED_TP_STYLES,
shard_and_distribute_module,
translate_to_torch_parallel_style,
)
from .loss.loss_utils import LOSS_MAPPING
from .pytorch_utils import ( # noqa: F401
Conv1D,
apply_chunking_to_forward,
distribute_module,
find_pruneable_heads_and_indices,
id_tensor_storage,
prune_conv1d_layer,
prune_layer,
prune_linear_layer,
translate_to_torch_parallel_style,
)
from .quantizers import AutoHfQuantizer, HfQuantizer
from .quantizers.quantizers_utils import get_module_from_name
@@ -151,6 +154,7 @@ logger = logging.get_logger(__name__)
_init_weights = True
_is_quantized = False
_is_ds_init_called = False
_torch_distributed_available = torch.distributed.is_available()
def is_fsdp_enabled():
@@ -181,8 +185,6 @@ else:
if is_peft_available():
from .utils import find_adapter_config_file
if is_torch_greater_or_equal("2.5"):
from torch.distributed.tensor import DTensor, Shard
SpecificPreTrainedModelType = TypeVar("SpecificPreTrainedModelType", bound="PreTrainedModel")
@@ -756,6 +758,40 @@ def _move_model_to_meta(model, loaded_state_dict_keys, start_prefix):
setattr(submodule, param_name, new_val)
def fix_tensor_type_and_device(
model, param_name, param, dtype=None, keep_in_fp32_modules=None
) -> Union[str, torch.dtype]:
# For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model, and which
# uses `param.copy_(input_param)` that preserves the contiguity of the parameter in the model.
# Reference: https://github.com/pytorch/pytorch/blob/db79ceb110f6646523019a59bbd7b838f43d4a86/torch/nn/modules/module.py#L2040C29-L2040C29
old_param = model
if "." in param_name:
pre, _ = param_name.rsplit(".", 1)
old_param = model.get_submodule(pre)
if not isinstance(old_param, (torch.nn.Parameter, torch.Tensor)):
old_param = None
is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn")
# We convert floating dtypes to the `dtype` passed except for float8_e4m3fn type. We also want to keep the buffers/params
# in int/uint/bool and not cast them.
param_casting_dtype = None
is_param_float8_e4m3fn = is_torch_e4m3fn_available and param.dtype == torch.float8_e4m3fn
if param.dtype.is_floating_point and not is_param_float8_e4m3fn:
if keep_in_fp32_modules is not None and keep_in_fp32_modules.search(param_name):
param_casting_dtype = torch.float32
elif dtype is not None:
param_casting_dtype = dtype
elif old_param is not None:
param_casting_dtype = old_param.dtype
return old_param is not None and old_param.is_contiguous(), param_casting_dtype
else:
return False, None
return
@torch.no_grad()
def _load_state_dict_into_meta_model(
model: torch.nn.Module,
@@ -787,18 +823,12 @@ def _load_state_dict_into_meta_model(
It also initialize tensor parallelism for each module if needed.
"""
tensor_device = None
tensor_device = "cpu"
if device_map is not None and device_map.get("", None) is not None:
tensor_device = device_map[""].index if isinstance(device_map[""], torch.device) else device_map[""]
if device_map is not None:
device_map_regex = "|".join(sorted(device_map.keys(), reverse=True))
# we need this later to initialize tensor parallelism
if device_mesh is not None:
full_tp_plan = model.config.base_model_tp_plan
for submodule in model.modules():
full_tp_plan.update(getattr(submodule, "_tp_plan", {}))
file_pointer = None
bin_state_dict = None
if shard_file.endswith(".safetensors"):
@@ -818,8 +848,6 @@ def _load_state_dict_into_meta_model(
is_quantized = hf_quantizer is not None
is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn")
for serialized_param_name, empty_param in state_dict.items():
# serialized_param_name is the raw, serialized name
# fixed_param_name is the model's equivalent
@@ -829,87 +857,37 @@ def _load_state_dict_into_meta_model(
continue
# we need to use serialized_param_name as file pointer is untouched
param = (
file_pointer.get_slice(serialized_param_name)
if shard_file.endswith(".safetensors")
else bin_state_dict[serialized_param_name]
if shard_file.endswith(".safetensors"):
param = file_pointer.get_slice(serialized_param_name)
elif shard_file.endswith(".gguf"):
param = empty_param # For gguf the dict is actually not empty!
else:
param = bin_state_dict[serialized_param_name]
to_contiguous, param_casting_dtype = fix_tensor_type_and_device(
model,
param_name=fixed_param_name,
param=empty_param,
dtype=dtype,
keep_in_fp32_modules=keep_in_fp32_modules,
)
# For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model, and which
# uses `param.copy_(input_param)` that preserves the contiguity of the parameter in the model.
# Reference: https://github.com/pytorch/pytorch/blob/db79ceb110f6646523019a59bbd7b838f43d4a86/torch/nn/modules/module.py#L2040C29-L2040C29
old_param = model
splits = fixed_param_name.split(".")
for split in splits:
# We shouldn't hit the default value unless for quant methods like hqq that modifies expected_keys.
old_param = getattr(old_param, split, None)
if old_param is None:
break
if not isinstance(old_param, (torch.nn.Parameter, torch.Tensor)):
old_param = None
# We convert floating dtypes to the `dtype` passed except for float8_e4m3fn type. We also want to keep the buffers/params
# in int/uint/bool and not cast them.
param_casting_dtype = None
is_param_float8_e4m3fn = is_torch_e4m3fn_available and empty_param.dtype == torch.float8_e4m3fn
if empty_param.dtype.is_floating_point and not is_param_float8_e4m3fn:
if keep_in_fp32_modules is not None and keep_in_fp32_modules.search(fixed_param_name):
param_casting_dtype = torch.float32
elif dtype is not None:
param_casting_dtype = dtype
elif old_param is not None:
param_casting_dtype = old_param.dtype
if device_mesh is not None: # In this case, the param is already on the correct device!
module_to_tp, param_type = get_module_from_name(model, fixed_param_name)
current_module_plan = None
full_tp_plan_ = "|".join(full_tp_plan.keys()).replace("*", "[0-9]+")
if plan := re.search(full_tp_plan_, fixed_param_name):
match = re.sub("[0-9]+", "*", plan[0])
current_module_plan = full_tp_plan[match]
if current_module_plan is not None:
tp_layer = translate_to_torch_parallel_style(current_module_plan)
rank = tensor_device
row, col = empty_param.shape
if "rowwise" == current_module_plan:
param = param[:, rank * (col // device_mesh.size()) : (rank + 1) * (col // device_mesh.size())]
shard = Shard(1)
tp_layer.desired_input_layouts = (Shard(-1),)
elif "colwise" == current_module_plan:
param = param[rank * (row // device_mesh.size()) : (rank + 1) * (row // device_mesh.size()), :]
shard = Shard(0)
else:
param = param[rank * (row // device_mesh.size()) : (rank + 1) * (row // device_mesh.size()), :]
shard = Shard(0)
if param_casting_dtype is not None:
param = param.to(param_casting_dtype)
if old_param.is_contiguous():
param = param.contiguous()
local_parameter = DTensor.from_local(
param,
device_mesh=device_mesh,
placements=[shard] * device_mesh.ndim,
)
if isinstance(module_to_tp.weight, nn.Parameter):
local_parameter = torch.nn.Parameter(local_parameter)
module_to_tp.weight = local_parameter
input_fn = partial(tp_layer._prepare_input_fn, tp_layer.input_layouts, tp_layer.desired_input_layouts)
output_fn = partial(tp_layer._prepare_output_fn, tp_layer.output_layouts, tp_layer.use_local_output)
distribute_module(module_to_tp, device_mesh, None, input_fn, output_fn)
else:
param = param[:]
if old_param is not None and old_param.is_contiguous():
param = param.contiguous()
module_to_tp.load_state_dict({param_type: param}, strict=False, assign=True)
shard_and_distribute_module(
model,
param,
empty_param,
fixed_param_name,
param_casting_dtype,
to_contiguous,
tensor_device, # the rank
device_mesh,
)
else:
param = param[:]
if param_casting_dtype is not None:
param = param.to(param_casting_dtype)
if old_param is not None and old_param.is_contiguous():
if to_contiguous:
param = param.contiguous()
if device_map is None:
@@ -966,6 +944,7 @@ def _load_state_dict_into_meta_model(
val_kwargs["requires_grad"] = False
value = type(value)(value.data.to(param_to), **val_kwargs, **value.__dict__)
setattr(module, param_type, value)
if file_pointer is not None:
file_pointer.__exit__(None, None, None)
@@ -1409,7 +1388,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# A tensor parallel plan to be applied to the model when TP is enabled. For
# top-level models, this attribute is currently defined in respective model
# code. For base models, this attribute comes from
# `config.base_model_tp_plan` during `post_init`.
# `config.base_model_tp_plan` during `__init__`.
# It should identify the layers exactly: if you want to TP model.language_model.layers.fc1
# by passing `tp_plan` to the init, it should be {"model.language_model.layers.fc1":"colwise"}
# for example.
_tp_plan = None
# A pipeline parallel plan specifying the layers which may not be present
@@ -1475,6 +1457,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# when a different component (e.g. language_model) is used.
self._keep_in_fp32_modules = copy.copy(self.__class__._keep_in_fp32_modules)
self._no_split_modules = self._no_split_modules or []
def post_init(self):
"""
A method executed at the end of each Transformer model initialization, to execute code that needs the model's
@@ -1482,11 +1466,23 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
"""
self.init_weights()
self._backward_compatibility_gradient_checkpointing()
# If current model is a base model, attach `base_model_tp_plan` and `base_model_pp_plan` from config
if self.base_model is self:
self._tp_plan = self.config.base_model_tp_plan
self._pp_plan = self.config.base_model_pp_plan
self._tp_plan = self._tp_plan or self.config.base_model_tp_plan or {}
for name, module in self.named_children():
if plan := getattr(module, "_tp_plan", None):
self._tp_plan.update({f"{name}.{k}": v for k, v in plan.items()})
if self._tp_plan is not None and is_torch_greater_or_equal("2.3"):
for _, v in self._tp_plan.items():
if v not in SUPPORTED_TP_STYLES:
raise ValueError(
f"Unsupported tensor parallel style {v}. Supported styles are {SUPPORTED_TP_STYLES}"
)
def dequantize(self):
"""
Potentially dequantize the model in case it has been quantized by a quantization method that support
@@ -4315,7 +4311,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
model = cls(config, *model_args, **model_kwargs)
if device_mesh is not None and not model.supports_tp_plan:
raise NotImplementedError("This model does not have a tensor parallel plan.")
if config.base_model_tp_plan is None and config.get_text_config().base_model_tp_plan is None:
raise NotImplementedError("This model does not have a tensor parallel plan.")
# make sure we use the model's config since the __init__ call might have copied it
config = model.config
@@ -4453,7 +4450,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
model,
state_dict,
loaded_state_dict_keys, # XXX: rename?
resolved_archive_file,
resolved_archive_file or gguf_file,
pretrained_model_name_or_path,
ignore_mismatched_sizes=ignore_mismatched_sizes,
sharded_metadata=sharded_metadata,
@@ -4565,7 +4562,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
@staticmethod
def _fix_state_dict_key_on_load(key) -> Tuple[str, bool]:
"""Replace legacy parameter names with their modern equivalents. E.g. beta -> bias, gamma -> weight."""
# Rename LayerNorm beta & gamma params for some early models ported from Tensorflow (e.g. Bert)
# This rename is logged.
if key.endswith("LayerNorm.beta"):
@@ -4590,6 +4586,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
return key, False
def rename_key(self, key):
"""
When we load a LlamaModel from a checkpoint made using LlamaForCausalLM, the keys have an extra
prefix, which can be accessed in the `LlamaModel` via the `self.base_model_prefix` attribute.
But, what if there is an extra layer on top of it? You load a MistralModel from a LlavaForConditionalGeneration?
In that what you actually want is to cut whatever is left of the key.
"""
new_key = key
if len(self.base_model_prefix) > 0:
if not hasattr(self, self.base_model_prefix) and key.startswith(self.base_model_prefix):
@@ -4940,7 +4943,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
keep_in_fp32_modules=keep_in_fp32_modules,
unexpected_keys=unexpected_keys,
device_mesh=device_mesh,
resolved_archive_file=resolved_archive_file,
shard_file=resolved_archive_file,
weights_only=weights_only,
)
else:
@@ -5019,7 +5022,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
model_to_load,
state_dict,
start_prefix,
prefix,
expected_keys,
device_map=device_map,
offload_folder=offload_folder,
@@ -5898,10 +5901,21 @@ def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: Dict,
for param_name, device in accelerator_device_map.items():
try:
param = model.get_parameter(param_name)
param = getattr(model, param_name)
except AttributeError:
param = model.get_buffer(param_name)
parameter_count[device] += int(math.prod(param.shape) * allocation_factor)
if "." in param_name:
param_name, param_type = param_name.rsplit(".", 1)
param = getattr(model.get_submodule(param_name), param_type)
else:
param = model.get_buffer(param_name)
param_size = int(math.prod(param.shape) * allocation_factor)
if _torch_distributed_available and torch.distributed.is_initialized():
generic_name = re.sub(r"\d+", "*", param_name)
param_size //= torch.distributed.get_world_size() if not model._tp_plan.get(generic_name, False) else 1
parameter_count[device] += param_size
dtype = dtype if dtype is not None else torch.float32