From 1c4b62b219323a31011bac3bd3cece7675d9e4c3 Mon Sep 17 00:00:00 2001 From: Arthur <48595927+ArthurZucker@users.noreply.github.com> Date: Tue, 11 Mar 2025 09:26:28 +0100 Subject: [PATCH] Refactor some core stuff (#36539) * some config changes * update * current state * update * update * updates and cleanup * something that works * fixup * fixes * nits * nit * nits and fix * Update src/transformers/integrations/tensor_parallel.py Co-authored-by: Lysandre Debut * Update src/transformers/integrations/tensor_parallel.py Co-authored-by: Lysandre Debut * cleanup * style * safe import * fix * updates * rename stuff an clean * style * small updates * ups * oups * nit * protect imports * update tp * rodfl * arf * turbo nit on init * fix import error * frumble gumbgle * try to fix the import error * should fix the non model test * update keep in float32 * update * fix * nits * fix subvconfigs * test was weird * nit * fix failing test * fix instruct blip * fixes * style * x.com * fix overwrite * ok last bit of failing test --------- Co-authored-by: Lysandre Debut --- src/transformers/configuration_utils.py | 22 +- src/transformers/integrations/__init__.py | 26 +- .../integrations/tensor_parallel.py | 544 ++++++++++++++++++ src/transformers/modeling_utils.py | 208 +++---- .../models/blip_2/modeling_blip_2.py | 8 +- .../instructblip/modeling_instructblip.py | 2 +- .../modeling_instructblipvideo.py | 2 +- .../models/llava/test_configuration_llava.py | 8 +- .../test_tensor_parallel.py} | 0 9 files changed, 704 insertions(+), 116 deletions(-) create mode 100644 src/transformers/integrations/tensor_parallel.py rename tests/{tp/test_tp.py => tensor_parallel/test_tensor_parallel.py} (100%) diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index 3e7effc4ab..9b29b0d2b9 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -824,25 +824,27 @@ class PretrainedConfig(PushToHubMixin): serializable_config_dict = {} - # only serialize values that differ from the default config + # Only serialize values that differ from the default config, + # except always keep the 'config' attribute. for key, value in config_dict.items(): if ( isinstance(getattr(self, key, None), PretrainedConfig) and key in class_config_dict and isinstance(class_config_dict[key], dict) + or key in self.sub_configs ): # For nested configs we need to clean the diff recursively - diff = recursive_diff_dict(value, class_config_dict[key], config_obj=getattr(self, key, None)) + diff = recursive_diff_dict(value, default_config_dict, config_obj=getattr(self, key, None)) if "model_type" in value: # Needs to be set even if it's not in the diff diff["model_type"] = value["model_type"] - if len(diff) > 0: - serializable_config_dict[key] = diff + serializable_config_dict[key] = diff elif ( key not in default_config_dict or key == "transformers_version" + or key == "vocab_file" or value != default_config_dict[key] - or (key in class_config_dict and value != class_config_dict[key]) + or (key in default_config_dict and value != class_config_dict.get(key, value)) ): serializable_config_dict[key] = value @@ -867,6 +869,9 @@ class PretrainedConfig(PushToHubMixin): if "base_model_pp_plan" in serializable_config_dict: del serializable_config_dict["base_model_pp_plan"] + if "_name_or_path" in serializable_config_dict: + del serializable_config_dict["_name_or_path"] + return serializable_config_dict def to_dict(self) -> Dict[str, Any]: @@ -1178,6 +1183,8 @@ def recursive_diff_dict(dict_a, dict_b, config_obj=None): """ Helper function to recursively take the diff between two nested dictionaries. The resulting diff only contains the values from `dict_a` that are different from values in `dict_b`. + + dict_b : the default config dictionnary. We want to remove values that are in this one """ diff = {} default = config_obj.__class__().to_dict() if config_obj is not None else {} @@ -1185,9 +1192,8 @@ def recursive_diff_dict(dict_a, dict_b, config_obj=None): obj_value = getattr(config_obj, str(key), None) if isinstance(obj_value, PretrainedConfig) and key in dict_b and isinstance(dict_b[key], dict): diff_value = recursive_diff_dict(value, dict_b[key], config_obj=obj_value) - if len(diff_value) > 0: - diff[key] = diff_value - elif key not in dict_b or value != dict_b[key] or key not in default or value != default[key]: + diff[key] = diff_value + elif key not in dict_b or (value != default[key]): diff[key] = value return diff diff --git a/src/transformers/integrations/__init__.py b/src/transformers/integrations/__init__.py index 5c94349cbc..7e413cd064 100755 --- a/src/transformers/integrations/__init__.py +++ b/src/transformers/integrations/__init__.py @@ -13,7 +13,7 @@ # limitations under the License. from typing import TYPE_CHECKING -from ..utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available +from ..utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_torch_greater_or_equal _import_structure = { @@ -128,6 +128,18 @@ else: "convert_and_export_with_cache", ] +try: + if not is_torch_greater_or_equal("2.3"): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["tensor_parallel"] = [ + "shard_and_distribute_module", + "SUPPORTED_TP_STYLES", + "translate_to_torch_parallel_style", + ] + if TYPE_CHECKING: from .aqlm import replace_with_aqlm_linear from .awq import ( @@ -231,6 +243,18 @@ if TYPE_CHECKING: else: from .executorch import TorchExportableModuleWithStaticCache, convert_and_export_with_cache + try: + if not is_torch_greater_or_equal("2.3"): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .tensor_parallel import ( + SUPPORTED_TP_STYLES, + shard_and_distribute_module, + translate_to_torch_parallel_style, + ) + else: import sys diff --git a/src/transformers/integrations/tensor_parallel.py b/src/transformers/integrations/tensor_parallel.py new file mode 100644 index 0000000000..9e8a0dec76 --- /dev/null +++ b/src/transformers/integrations/tensor_parallel.py @@ -0,0 +1,544 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import re +from functools import lru_cache, partial +from typing import List, Optional, Tuple, Union + +import torch +from torch import nn + +from ..utils import is_torch_greater_or_equal, logging + + +ALL_LAYERNORM_LAYERS = [nn.LayerNorm] + +logger = logging.get_logger(__name__) + +# Cache this result has it's a C FFI call which can be pretty time-consuming +_torch_distributed_available = torch.distributed.is_available() + + +if is_torch_greater_or_equal("2.5") and _torch_distributed_available: + from torch.distributed.tensor import DTensor, Placement, Replicate, Shard + + +def _blocks_to_block_sizes(total_size: int, blocks: Union[int, List[int]]) -> List[int]: + """ + Convert block count or proportions to block sizes. + + This function accepts + + - The number of blocks (int), in which case the block size is + total_size//blocks; or + - A list of block sizes (List[int]). + + In the second case, if sum(blocks) < total_size, the ratios between + the block sizes will be preserved. For instance, if blocks is + [2, 1, 1] and total_size is 1024, the returned block sizes are + [512, 256, 256]. + """ + if isinstance(blocks, list): + total_blocks = sum(blocks) + assert total_size % total_blocks == 0, f"Cannot split {total_size} in proportional blocks: {blocks}" + part_size = total_size // total_blocks + return [part_size * block for block in blocks] + else: + assert total_size % blocks == 0, f"Prepacked is not divisible by {blocks}" + single_size = total_size // blocks + return [single_size] * blocks + + +def get_packed_weights(param, empty_param, device_mesh, rank, dim): + """ + When weights are packed (gate_up_proj), we need to make sure each shard gets its correct share. + So if you have: gate_proj ( 16, 5120, 8190) + and up_proj ( 16, 5120, 8190) + packed as gate_up_proj ( 16, 5120, 2 * 8190) + And you shard along the last dimension, you need to interleave the gate and up values: + + Now, if we shard along the last dimension across TP_size (Tensor Parallelism size), we must interleave the values from gate and up projections correctly. + + Let's take TP_size = 4 for an example: + + Packed tensor `gate_up_proj` + --------------------------------------------------------------- + [ G0 G1 G2 G3 | G4 G5 G6 G7 | ... | U0 U1 U2 U3 | U4 U5 U6 U7 | ... ] + ↑─────────────↑ ↑─────────────↑ ↑─────────────↑ ↑─────────────↑ + Gate Slice 0 Gate Slice 1 Up Slice 0 Up Slice 1 + + Explanation: + - The first half of the tensor (left of the center) holds the gate_proj values. + - The second half (right of the center) holds the up_proj values. + - For TP=4, we divide each half into 4 slices. In this example, we show two slices for brevity. + - Each shard receives one slice from the gate part and the corresponding slice from the up part. + + For instance: + • Shard 0 gets: [ Gate Slice 0, Up Slice 0 ] = [ G0, G1, G2, G3, U0, U1, U2, U3 ] + • Shard 1 gets: [ Gate Slice 1, Up Slice 1 ] = [ G4, G5, G6, G7, U4, U5, U6, U7 ] + • … and so on. + + This ensures that each shard receives an equal portion of both gate and up projections, maintaining consistency across tensor parallelism. + """ + slice_ = param + total_size = empty_param.shape[dim] + world_size = device_mesh.size() + block_sizes = _blocks_to_block_sizes(total_size=total_size, blocks=2) + + tensors_slices = [] + block_offset = 0 + for block_size in block_sizes: + shard_block_size = block_size // world_size + start = rank * shard_block_size + stop = (rank + 1) * shard_block_size + tensors_slices += range(block_offset + start, block_offset + stop) + block_offset += block_size + + if dim == 0: + tensor = slice_[tensors_slices, ...] + elif dim == 1 or dim == -2: + tensor = slice_[:, tensors_slices, ...] + elif dim == 2 or dim == -1: + tensor = slice_[..., tensors_slices] + else: + raise ValueError(f"Unsupported dim {dim}, only dim 0, 1 or 2 are supported") + return tensor + + +def get_tensor_shard(param, empty_param, device_mesh, rank, dim): + if dim == 0: + size_ = empty_param.shape[0] + param = param[rank * (size_ // device_mesh.size()) : (rank + 1) * (size_ // device_mesh.size()), ...] + elif dim == 1 or dim == -2: + size_ = empty_param.shape[-2] + param = param[..., rank * (size_ // device_mesh.size()) : (rank + 1) * (size_ // device_mesh.size()), :] + elif dim == 2 or dim == -1: + size_ = empty_param.shape[-1] + param = param[..., rank * (size_ // device_mesh.size()) : (rank + 1) * (size_ // device_mesh.size())] + else: + raise ValueError(f"Unsupported dim {dim}, only dim 0, 1 or 2 are supported") + return param + + +def distribute_module( + module: nn.Module, + device_mesh=None, + input_fn=None, + output_fn=None, +) -> nn.Module: + """ + Copy pasted from torch's function but we remove the communications (partitionning) + as well as buffer registering that is similarly not efficient. + """ + if len(module._forward_pre_hooks) == 0: + if input_fn is not None: + module.register_forward_pre_hook(lambda mod, inputs: input_fn(mod, inputs, device_mesh)) + if output_fn is not None: + module.register_forward_hook(lambda mod, inputs, outputs: output_fn(mod, outputs, device_mesh)) + return module + + +class TensorParallelLayer: + """ + General tensor parallel layer for transformers. + """ + + use_dtensor = True + + @staticmethod + def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh): ... + + @staticmethod + def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh): ... + + def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh): + raise NotImplementedError + + def prepare_module_tp(self, module: nn.Module, device_mesh) -> nn.Module: + if self.use_dtensor: + distribute_module( + module, + device_mesh, + partial(self._prepare_input_fn, self.input_layouts, self.desired_input_layouts), + partial(self._prepare_output_fn, self.output_layouts, self.use_local_output), + ) + + +# use_dtensor needs to be set to false for nn.Parameter when you want to view, chunk, slice +# you name it. Whatever you want to do that is a bit unconventional, you need local tensors +class GatherParallel(TensorParallelLayer): + """ + Simple class used to define the hooks to add to a layer when we just want to gather the outputs + """ + + def __init__( + self, + *, + input_layouts: Optional[Placement] = None, + output_layouts: Optional[Placement] = None, + use_local_output: bool = True, + ): + super().__init__() + self.input_layouts = (input_layouts or Replicate(),) + self.output_layouts = output_layouts + self.desired_input_layouts = (Replicate(),) + self.use_local_output = use_local_output + + @staticmethod + def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh): + if isinstance(inputs[0], DTensor): + inputs[0] = inputs[0].to_local() + return inputs + + @staticmethod + def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh): + torch.distributed.all_reduce(outputs[0], op=torch.distributed.ReduceOp.SUM, async_op=False) + return outputs + + +class IsolatedParallel(TensorParallelLayer): + """ + This class is used to isolate computation in a TP layer from the rest of the world. + Parameters need to be LOCAL, so not dtensors + """ + + @staticmethod + def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh=None): + # annotate module input placements/sharding with input_layouts + input_tensor = inputs[0] + if isinstance(input_tensor, DTensor): + input_tensor = input_tensor.to_local() + return input_tensor + + @staticmethod + def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh=None): + # TODO: figure out dynamo support for instance method and switch this to instance method + return outputs + + def prepare_module_tp(self, module: nn.Module, device_mesh) -> nn.Module: + distribute_module( + module, + device_mesh, + partial(self._prepare_input_fn), + partial(self._prepare_output_fn), + ) + + +class ColwiseParallel(TensorParallelLayer): + """ + General tensor parallel layer for transformers. + """ + + def __init__( + self, + *, + input_layouts: Optional[Placement] = None, + output_layouts: Optional[Placement] = None, + use_local_output: bool = True, + use_dtensor=True, + ): + super().__init__() + self.input_layouts = (input_layouts or Replicate(),) + self.output_layouts = (output_layouts or Shard(-1),) + self.desired_input_layouts = (Replicate(),) + self.use_local_output = use_local_output + self.use_dtensor = use_dtensor + + @staticmethod + def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh): + # TODO: figure out dynamo support for instance method and switch this to instance method + # annotate module input placements/sharding with input_layouts + input_tensor = inputs[0] + if not isinstance(input_tensor, DTensor): + input_tensor = DTensor.from_local(input_tensor, device_mesh, input_layouts, run_check=False) + + # transform the input layouts to the desired layouts of ColwiseParallel + if input_layouts != desired_input_layouts: + input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=True) + return input_tensor + + def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh): + # colwise shard weight/bias to Shard(0), weight be Shard(-2) (0 if you have 1 dim only) + # means Colwise as Linear is input * weight^T + bias, where + # weight would become Shard(1) + if param_type == "bias": + parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -1) + shard = [Shard(-1)] + else: + shard = [Shard(-2)] + parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -2) + + parameter = parameter.to(param_casting_dtype) + if to_contiguous: + parameter = parameter.contiguous() + if self.use_dtensor: + parameter = DTensor.from_local(parameter, device_mesh, shard, run_check=False) + return nn.Parameter(parameter) + + @staticmethod + def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh): + # outputs is a shard on last dimension DTensor, i.e. Shard(-1) + if outputs.placements != output_layouts: + outputs = outputs.redistribute(placements=output_layouts, async_op=True) + # back to local tensor + return outputs.to_local() if use_local_output else outputs + + +class PackedColwiseParallel(ColwiseParallel): + def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh): + # colwise shard weight/bias to Shard(0), weight be Shard(-2) (0 if you have 1 dim only) + # means Colwise as Linear is input * weight^T + bias, where + # weight would become Shard(1) + parameter = get_packed_weights(param, empty_param, device_mesh, rank, -2) + parameter = parameter.to(param_casting_dtype) + if to_contiguous: + parameter = parameter.contiguous() + if self.use_dtensor: + parameter = DTensor.from_local(parameter, device_mesh, [Shard(-2)], run_check=False) + return nn.Parameter(parameter) + + +class RowwiseParallel(TensorParallelLayer): + """ + Partition a compatible nn.Module in a row-wise fashion. Currently supports nn.Linear and nn.Embedding. + Users can compose it with ColwiseParallel to achieve the sharding of more complicated modules. + (i.e. MLP, Attention) + + Keyword Args: + input_layouts (Placement, optional): + The DTensor layout of input tensor for the nn.Module, this is used to annotate the input tensor to + become a DTensor. If not specified, we assume the input tensor to be sharded on the last dimension. + output_layouts (Placement, optional): + The DTensor layout of the output for the nn.Module, this is used to ensure the output of the nn.Module + with the user desired layout. If not specified, the output tensor is replicated. + use_local_output (bool, optional): + Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module output, default: True. + Returns: + A :class:`ParallelStyle` object that represents Rowwise sharding of the nn.Module. + """ + + def __init__( + self, + *, + input_layouts: Optional[Placement] = None, + output_layouts: Optional[Placement] = None, + use_local_output: bool = True, + use_dtensor=True, + ): + super().__init__() + self.input_layouts = (input_layouts or Shard(-1),) + self.output_layouts = (output_layouts or Replicate(),) + self.use_local_output = use_local_output + self.use_dtensor = use_dtensor + + @staticmethod + def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh): + input_tensor = inputs[0] + if not isinstance(input_tensor, DTensor): + input_tensor = DTensor.from_local(input_tensor, device_mesh, input_layouts, run_check=False) + + if input_layouts != desired_input_layouts: + input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=True) + return input_tensor + + def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh): + # Rowwise shard weight to Shard(1), bias to Replicate(), weight be Shard(1) + # means Rowwise as nn.Linear is input * weight^T + bias, where + # weight would become Shard(0) + if param_type != "bias": + parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -1) + shard = [Shard(-1)] + else: + shard = [Replicate()] + parameter = param[:] + + parameter = parameter.to(param_casting_dtype) + if to_contiguous: + parameter = parameter.contiguous() + if self.use_dtensor: + parameter = DTensor.from_local(parameter, device_mesh, shard, run_check=False) + return nn.Parameter(parameter) + + @staticmethod + def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh): + # Rowwise sharding produces partial output, depending on output layouts: + # 1. to replicate -> allreduce + # 2. to shard -> reduce_scatter + if outputs.placements != output_layouts: + outputs = outputs.redistribute(placements=output_layouts, async_op=True) + # back to local tensor if use_local_output is True + return outputs.to_local() if use_local_output else outputs + + def prepare_module_tp(self, module: nn.Module, device_mesh) -> nn.Module: + module._distribute_module_applied = True + if self.use_dtensor: + if isinstance(module, nn.Linear): + # rowwise linear runtime sharding requires input tensor shard on last dim + self.desired_input_layouts: Tuple[Placement, ...] = (Shard(-1),) + elif isinstance(module, nn.Embedding): + # rowwise embedding runtime sharding requires input tensor replicated + self.desired_input_layouts = (Replicate(),) + elif isinstance(module, nn.Parameter): + # rowwise embedding runtime sharding requires input tensor replicated + self.desired_input_layouts = (Shard(-1),) + else: + raise NotImplementedError("RowwiseParallel currently only support nn.Linear and nn.Embedding!") + + distribute_module( + module, + device_mesh, + partial(self._prepare_input_fn, self.input_layouts, self.desired_input_layouts), + partial(self._prepare_output_fn, self.output_layouts, self.use_local_output), + ) + + +class PackedRowwiseParallel(RowwiseParallel): + def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh): + # colwise shard weight/bias to Shard(0), weight be Shard(-2) (0 if you have 1 dim only) + # means Colwise as Linear is input * weight^T + bias, where + # weight would become Shard(1) + parameter = get_packed_weights(param, empty_param, device_mesh, rank, -1) + parameter = parameter.to(param_casting_dtype) + if to_contiguous: + parameter = parameter.contiguous() + if self.use_dtensor: + parameter = DTensor.from_local(parameter, device_mesh, [Shard(-1)], run_check=False) + return nn.Parameter(parameter) + + +SUPPORTED_TP_STYLES = { + "colwise", + "rowwise", + "colwise_rep", + "rowwise_rep", + "local_colwise", + "local_rowwise", + "local", + "gather", + "local_packed_rowwise", +} + + +@lru_cache +def translate_to_torch_parallel_style(style: str): + """ + In model configurations, we use a neutral type (string) to specify parallel + styles, here we translate them into torch.distributed tensor-parallel + types. + """ + if not isinstance(style, str): + raise ValueError(f"Unsupported parallel style type {type(style)}, expected str") + + if style == "colwise": + return ColwiseParallel() + elif style == "rowwise": + return RowwiseParallel() + elif style == "colwise_rep": + return ColwiseParallel(output_layouts=Replicate()) + elif style == "rowwise_rep": + return RowwiseParallel(input_layouts=Replicate()) + elif style == "local_colwise": + return ColwiseParallel(use_dtensor=False) + elif style == "local_rowwise": + return RowwiseParallel(use_dtensor=False) + elif style == "local": + return IsolatedParallel() + elif style == "gather": + return GatherParallel() + elif style == "local_packed_rowwise": + return PackedRowwiseParallel(use_dtensor=False) + else: + raise ValueError(f"Unsupported parallel style value: {style}") + + +def add_tensor_parallel_hooks_to_module(model, module, tp_plan, layer_name, current_module_plan, device_mesh): + """ + Add hooks to the module holding the layer. Meaning: + ``` + class MyModel(nn.Module): + def __init__(self): + self.layer = nn.Linear(10, 10) + ``` + has state_dict like: + ``` + { + "layer.weight": torch.Tensor, + "layer.bias": torch.Tensor + } + ``` + we add hooks to `MyModel` as well as `layer` to make sure that the tensors are correctly sharded and gathered. + """ + + # 1. We add hooks to the layer being loaded: + if current_module_plan is not None: + tp_layer = translate_to_torch_parallel_style(current_module_plan) + tp_layer.prepare_module_tp(module, device_mesh) + + # 2. We add hooks to the parrent module if needed + if "." in layer_name: + parrent_layer_name = layer_name.rsplit(".", 1)[0] + generic_name = re.sub(r"\d+", "*", parrent_layer_name) + # The module itself needs hooks + if module_plan := tp_plan.get(generic_name, False): + tp_layer = translate_to_torch_parallel_style(module_plan) + module_to_tp_ = model.get_submodule(parrent_layer_name) + tp_layer.prepare_module_tp(module_to_tp_, device_mesh) + + +def shard_and_distribute_module( + model, param, empty_param, parameter_name, param_casting_dtype, is_contiguous, rank, device_mesh +): + r""" + Main uses cases: + - column / rowise parallelism, you just shard all the weights of the layer (weight and bias) + - packed layers: you slice the weights, then shard like above + - custom operation: + - you want to add an all-gather at the end of a local layer. + - you want to have a layer that is isolated from the rest of the world (because torch.DTensor does not work well with `.view` for instance) + + """ + param_name, param_type = parameter_name.rsplit(".", 1) if "." in parameter_name else parameter_name + tp_plan = model._tp_plan + module_to_tp = model.get_submodule(param_name) + current_module_plan = None + generic_param_name = re.sub(r"\d+", "*", parameter_name) + if generic_param_name in tp_plan: + current_module_plan = tp_plan[generic_param_name] + elif "." in generic_param_name and generic_param_name.rsplit(".", 1)[0] in tp_plan: + current_module_plan = tp_plan[generic_param_name.rsplit(".", 1)[0]] + + # Add hooks to the module if not done yet + # add_tensor_parallel_hooks_to_module(model, module_to_tp, tp_plan, param_name, current_module_plan, device_mesh) + if not getattr(module_to_tp, "_is_hooked", False): + add_tensor_parallel_hooks_to_module(model, module_to_tp, tp_plan, param_name, current_module_plan, device_mesh) + module_to_tp._is_hooked = True + + if current_module_plan is not None: + tp_layer = translate_to_torch_parallel_style(current_module_plan) + param = tp_layer.partition_tensor( + param, empty_param, param_type, param_casting_dtype, is_contiguous, rank, device_mesh + ) + else: + param = param[:] + if is_contiguous: + param = param.contiguous() + + # SUPER IMPORTANT we have to use setattr + # otherwise loading is crazy slow + if not isinstance(param, torch.nn.Parameter): + param = torch.nn.Parameter(param) + setattr(module_to_tp, param_type, param) + # module_to_tp.load_state_dict({param_type: param}, strict=False, assign=True) + return param diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 29e55e01be..4a1a683c6b 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -54,17 +54,20 @@ from .integrations.deepspeed import _load_state_dict_into_zero3_model from .integrations.flash_attention import flash_attention_forward from .integrations.flex_attention import flex_attention_forward from .integrations.sdpa_attention import sdpa_attention_forward +from .integrations.tensor_parallel import ( + SUPPORTED_TP_STYLES, + shard_and_distribute_module, + translate_to_torch_parallel_style, +) from .loss.loss_utils import LOSS_MAPPING from .pytorch_utils import ( # noqa: F401 Conv1D, apply_chunking_to_forward, - distribute_module, find_pruneable_heads_and_indices, id_tensor_storage, prune_conv1d_layer, prune_layer, prune_linear_layer, - translate_to_torch_parallel_style, ) from .quantizers import AutoHfQuantizer, HfQuantizer from .quantizers.quantizers_utils import get_module_from_name @@ -151,6 +154,7 @@ logger = logging.get_logger(__name__) _init_weights = True _is_quantized = False _is_ds_init_called = False +_torch_distributed_available = torch.distributed.is_available() def is_fsdp_enabled(): @@ -181,8 +185,6 @@ else: if is_peft_available(): from .utils import find_adapter_config_file -if is_torch_greater_or_equal("2.5"): - from torch.distributed.tensor import DTensor, Shard SpecificPreTrainedModelType = TypeVar("SpecificPreTrainedModelType", bound="PreTrainedModel") @@ -756,6 +758,40 @@ def _move_model_to_meta(model, loaded_state_dict_keys, start_prefix): setattr(submodule, param_name, new_val) +def fix_tensor_type_and_device( + model, param_name, param, dtype=None, keep_in_fp32_modules=None +) -> Union[str, torch.dtype]: + # For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model, and which + # uses `param.copy_(input_param)` that preserves the contiguity of the parameter in the model. + # Reference: https://github.com/pytorch/pytorch/blob/db79ceb110f6646523019a59bbd7b838f43d4a86/torch/nn/modules/module.py#L2040C29-L2040C29 + + old_param = model + if "." in param_name: + pre, _ = param_name.rsplit(".", 1) + + old_param = model.get_submodule(pre) + if not isinstance(old_param, (torch.nn.Parameter, torch.Tensor)): + old_param = None + + is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn") + # We convert floating dtypes to the `dtype` passed except for float8_e4m3fn type. We also want to keep the buffers/params + # in int/uint/bool and not cast them. + param_casting_dtype = None + is_param_float8_e4m3fn = is_torch_e4m3fn_available and param.dtype == torch.float8_e4m3fn + if param.dtype.is_floating_point and not is_param_float8_e4m3fn: + if keep_in_fp32_modules is not None and keep_in_fp32_modules.search(param_name): + param_casting_dtype = torch.float32 + elif dtype is not None: + param_casting_dtype = dtype + elif old_param is not None: + param_casting_dtype = old_param.dtype + return old_param is not None and old_param.is_contiguous(), param_casting_dtype + else: + return False, None + + return + + @torch.no_grad() def _load_state_dict_into_meta_model( model: torch.nn.Module, @@ -787,18 +823,12 @@ def _load_state_dict_into_meta_model( It also initialize tensor parallelism for each module if needed. """ - tensor_device = None + tensor_device = "cpu" if device_map is not None and device_map.get("", None) is not None: tensor_device = device_map[""].index if isinstance(device_map[""], torch.device) else device_map[""] if device_map is not None: device_map_regex = "|".join(sorted(device_map.keys(), reverse=True)) - # we need this later to initialize tensor parallelism - if device_mesh is not None: - full_tp_plan = model.config.base_model_tp_plan - for submodule in model.modules(): - full_tp_plan.update(getattr(submodule, "_tp_plan", {})) - file_pointer = None bin_state_dict = None if shard_file.endswith(".safetensors"): @@ -818,8 +848,6 @@ def _load_state_dict_into_meta_model( is_quantized = hf_quantizer is not None - is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn") - for serialized_param_name, empty_param in state_dict.items(): # serialized_param_name is the raw, serialized name # fixed_param_name is the model's equivalent @@ -829,87 +857,37 @@ def _load_state_dict_into_meta_model( continue # we need to use serialized_param_name as file pointer is untouched - param = ( - file_pointer.get_slice(serialized_param_name) - if shard_file.endswith(".safetensors") - else bin_state_dict[serialized_param_name] + if shard_file.endswith(".safetensors"): + param = file_pointer.get_slice(serialized_param_name) + elif shard_file.endswith(".gguf"): + param = empty_param # For gguf the dict is actually not empty! + else: + param = bin_state_dict[serialized_param_name] + + to_contiguous, param_casting_dtype = fix_tensor_type_and_device( + model, + param_name=fixed_param_name, + param=empty_param, + dtype=dtype, + keep_in_fp32_modules=keep_in_fp32_modules, ) - # For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model, and which - # uses `param.copy_(input_param)` that preserves the contiguity of the parameter in the model. - # Reference: https://github.com/pytorch/pytorch/blob/db79ceb110f6646523019a59bbd7b838f43d4a86/torch/nn/modules/module.py#L2040C29-L2040C29 - - old_param = model - splits = fixed_param_name.split(".") - for split in splits: - # We shouldn't hit the default value unless for quant methods like hqq that modifies expected_keys. - old_param = getattr(old_param, split, None) - if old_param is None: - break - - if not isinstance(old_param, (torch.nn.Parameter, torch.Tensor)): - old_param = None - - # We convert floating dtypes to the `dtype` passed except for float8_e4m3fn type. We also want to keep the buffers/params - # in int/uint/bool and not cast them. - param_casting_dtype = None - is_param_float8_e4m3fn = is_torch_e4m3fn_available and empty_param.dtype == torch.float8_e4m3fn - if empty_param.dtype.is_floating_point and not is_param_float8_e4m3fn: - if keep_in_fp32_modules is not None and keep_in_fp32_modules.search(fixed_param_name): - param_casting_dtype = torch.float32 - elif dtype is not None: - param_casting_dtype = dtype - elif old_param is not None: - param_casting_dtype = old_param.dtype - if device_mesh is not None: # In this case, the param is already on the correct device! - module_to_tp, param_type = get_module_from_name(model, fixed_param_name) - current_module_plan = None - full_tp_plan_ = "|".join(full_tp_plan.keys()).replace("*", "[0-9]+") - if plan := re.search(full_tp_plan_, fixed_param_name): - match = re.sub("[0-9]+", "*", plan[0]) - current_module_plan = full_tp_plan[match] - - if current_module_plan is not None: - tp_layer = translate_to_torch_parallel_style(current_module_plan) - rank = tensor_device - row, col = empty_param.shape - if "rowwise" == current_module_plan: - param = param[:, rank * (col // device_mesh.size()) : (rank + 1) * (col // device_mesh.size())] - shard = Shard(1) - tp_layer.desired_input_layouts = (Shard(-1),) - elif "colwise" == current_module_plan: - param = param[rank * (row // device_mesh.size()) : (rank + 1) * (row // device_mesh.size()), :] - shard = Shard(0) - else: - param = param[rank * (row // device_mesh.size()) : (rank + 1) * (row // device_mesh.size()), :] - shard = Shard(0) - if param_casting_dtype is not None: - param = param.to(param_casting_dtype) - if old_param.is_contiguous(): - param = param.contiguous() - local_parameter = DTensor.from_local( - param, - device_mesh=device_mesh, - placements=[shard] * device_mesh.ndim, - ) - if isinstance(module_to_tp.weight, nn.Parameter): - local_parameter = torch.nn.Parameter(local_parameter) - module_to_tp.weight = local_parameter - input_fn = partial(tp_layer._prepare_input_fn, tp_layer.input_layouts, tp_layer.desired_input_layouts) - output_fn = partial(tp_layer._prepare_output_fn, tp_layer.output_layouts, tp_layer.use_local_output) - distribute_module(module_to_tp, device_mesh, None, input_fn, output_fn) - else: - param = param[:] - if old_param is not None and old_param.is_contiguous(): - param = param.contiguous() - module_to_tp.load_state_dict({param_type: param}, strict=False, assign=True) - + shard_and_distribute_module( + model, + param, + empty_param, + fixed_param_name, + param_casting_dtype, + to_contiguous, + tensor_device, # the rank + device_mesh, + ) else: param = param[:] if param_casting_dtype is not None: param = param.to(param_casting_dtype) - if old_param is not None and old_param.is_contiguous(): + if to_contiguous: param = param.contiguous() if device_map is None: @@ -966,6 +944,7 @@ def _load_state_dict_into_meta_model( val_kwargs["requires_grad"] = False value = type(value)(value.data.to(param_to), **val_kwargs, **value.__dict__) setattr(module, param_type, value) + if file_pointer is not None: file_pointer.__exit__(None, None, None) @@ -1409,7 +1388,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix # A tensor parallel plan to be applied to the model when TP is enabled. For # top-level models, this attribute is currently defined in respective model # code. For base models, this attribute comes from - # `config.base_model_tp_plan` during `post_init`. + # `config.base_model_tp_plan` during `__init__`. + # It should identify the layers exactly: if you want to TP model.language_model.layers.fc1 + # by passing `tp_plan` to the init, it should be {"model.language_model.layers.fc1":"colwise"} + # for example. _tp_plan = None # A pipeline parallel plan specifying the layers which may not be present @@ -1475,6 +1457,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix # when a different component (e.g. language_model) is used. self._keep_in_fp32_modules = copy.copy(self.__class__._keep_in_fp32_modules) + self._no_split_modules = self._no_split_modules or [] + def post_init(self): """ A method executed at the end of each Transformer model initialization, to execute code that needs the model's @@ -1482,11 +1466,23 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix """ self.init_weights() self._backward_compatibility_gradient_checkpointing() + # If current model is a base model, attach `base_model_tp_plan` and `base_model_pp_plan` from config if self.base_model is self: - self._tp_plan = self.config.base_model_tp_plan self._pp_plan = self.config.base_model_pp_plan + self._tp_plan = self._tp_plan or self.config.base_model_tp_plan or {} + for name, module in self.named_children(): + if plan := getattr(module, "_tp_plan", None): + self._tp_plan.update({f"{name}.{k}": v for k, v in plan.items()}) + + if self._tp_plan is not None and is_torch_greater_or_equal("2.3"): + for _, v in self._tp_plan.items(): + if v not in SUPPORTED_TP_STYLES: + raise ValueError( + f"Unsupported tensor parallel style {v}. Supported styles are {SUPPORTED_TP_STYLES}" + ) + def dequantize(self): """ Potentially dequantize the model in case it has been quantized by a quantization method that support @@ -4315,7 +4311,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix model = cls(config, *model_args, **model_kwargs) if device_mesh is not None and not model.supports_tp_plan: - raise NotImplementedError("This model does not have a tensor parallel plan.") + if config.base_model_tp_plan is None and config.get_text_config().base_model_tp_plan is None: + raise NotImplementedError("This model does not have a tensor parallel plan.") # make sure we use the model's config since the __init__ call might have copied it config = model.config @@ -4453,7 +4450,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix model, state_dict, loaded_state_dict_keys, # XXX: rename? - resolved_archive_file, + resolved_archive_file or gguf_file, pretrained_model_name_or_path, ignore_mismatched_sizes=ignore_mismatched_sizes, sharded_metadata=sharded_metadata, @@ -4565,7 +4562,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix @staticmethod def _fix_state_dict_key_on_load(key) -> Tuple[str, bool]: """Replace legacy parameter names with their modern equivalents. E.g. beta -> bias, gamma -> weight.""" - # Rename LayerNorm beta & gamma params for some early models ported from Tensorflow (e.g. Bert) # This rename is logged. if key.endswith("LayerNorm.beta"): @@ -4590,6 +4586,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix return key, False def rename_key(self, key): + """ + When we load a LlamaModel from a checkpoint made using LlamaForCausalLM, the keys have an extra + prefix, which can be accessed in the `LlamaModel` via the `self.base_model_prefix` attribute. + + But, what if there is an extra layer on top of it? You load a MistralModel from a LlavaForConditionalGeneration? + In that what you actually want is to cut whatever is left of the key. + """ new_key = key if len(self.base_model_prefix) > 0: if not hasattr(self, self.base_model_prefix) and key.startswith(self.base_model_prefix): @@ -4940,7 +4943,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix keep_in_fp32_modules=keep_in_fp32_modules, unexpected_keys=unexpected_keys, device_mesh=device_mesh, - resolved_archive_file=resolved_archive_file, + shard_file=resolved_archive_file, weights_only=weights_only, ) else: @@ -5019,7 +5022,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model( model_to_load, state_dict, - start_prefix, + prefix, expected_keys, device_map=device_map, offload_folder=offload_folder, @@ -5898,10 +5901,21 @@ def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: Dict, for param_name, device in accelerator_device_map.items(): try: - param = model.get_parameter(param_name) + param = getattr(model, param_name) except AttributeError: - param = model.get_buffer(param_name) - parameter_count[device] += int(math.prod(param.shape) * allocation_factor) + if "." in param_name: + param_name, param_type = param_name.rsplit(".", 1) + param = getattr(model.get_submodule(param_name), param_type) + else: + param = model.get_buffer(param_name) + + param_size = int(math.prod(param.shape) * allocation_factor) + + if _torch_distributed_available and torch.distributed.is_initialized(): + generic_name = re.sub(r"\d+", "*", param_name) + param_size //= torch.distributed.get_world_size() if not model._tp_plan.get(generic_name, False) else 1 + + parameter_count[device] += param_size dtype = dtype if dtype is not None else torch.float32 diff --git a/src/transformers/models/blip_2/modeling_blip_2.py b/src/transformers/models/blip_2/modeling_blip_2.py index 916631da7e..7a6475c152 100644 --- a/src/transformers/models/blip_2/modeling_blip_2.py +++ b/src/transformers/models/blip_2/modeling_blip_2.py @@ -419,7 +419,7 @@ class Blip2PreTrainedModel(PreTrainedModel): "OPTDecoderLayer", ] _skip_keys_device_placement = "past_key_values" - _keep_in_fp32_modules = ["wo"] + _keep_in_fp32_modules = ["query_tokens"] def _init_weights(self, module): """Initialize the weights""" @@ -1799,7 +1799,7 @@ class Blip2Model(Blip2PreTrainedModel): ) class Blip2TextModelWithProjection(Blip2PreTrainedModel): supports_gradient_checkpointing = False - _keep_in_fp32_modules = [] + _keep_in_fp32_modules = ["query_tokens"] def __init__(self, config: Blip2Config): super().__init__(config) @@ -1898,7 +1898,7 @@ class Blip2TextModelWithProjection(Blip2PreTrainedModel): ) class Blip2VisionModelWithProjection(Blip2PreTrainedModel): main_input_name = "pixel_values" - _keep_in_fp32_modules = [] + _keep_in_fp32_modules = ["query_tokens"] def __init__(self, config: Blip2Config): super().__init__(config) @@ -2371,7 +2371,7 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin): ) class Blip2ForImageTextRetrieval(Blip2PreTrainedModel): main_input_name = "pixel_values" - _keep_in_fp32_modules = [] + _keep_in_fp32_modules = ["query_tokens"] def __init__(self, config: Blip2Config): super().__init__(config) diff --git a/src/transformers/models/instructblip/modeling_instructblip.py b/src/transformers/models/instructblip/modeling_instructblip.py index ea42d65b84..d685dd6e99 100644 --- a/src/transformers/models/instructblip/modeling_instructblip.py +++ b/src/transformers/models/instructblip/modeling_instructblip.py @@ -322,7 +322,6 @@ class InstructBlipPreTrainedModel(PreTrainedModel): "InstructBlipQFormerMultiHeadAttention", "InstructBlipQFormerSelfOutput", ] - _keep_in_fp32_modules = [] # Copied from transformers.models.blip_2.modeling_blip_2.Blip2PreTrainedModel._init_weights with Blip2->InstructBlip def _init_weights(self, module): @@ -1293,6 +1292,7 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, Generati _supports_cache_class = True _supports_static_cache = True _supports_quantized_cache = False # not all LM bacbones support (e.g. T5) + _keep_in_fp32_modules = ["query_tokens"] # TODO @ArthurZucker I don't know why this is required for FP8 def __init__(self, config: InstructBlipConfig): super().__init__(config) diff --git a/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py b/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py index 5183a3c22f..8648d53b87 100644 --- a/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py +++ b/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py @@ -323,7 +323,6 @@ class InstructBlipVideoPreTrainedModel(PreTrainedModel): "InstructBlipVideoQFormerMultiHeadAttention", "InstructBlipVideoQFormerSelfOutput", ] - _keep_in_fp32_modules = [] def _init_weights(self, module): """Initialize the weights""" @@ -1287,6 +1286,7 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel _supports_cache_class = True _supports_static_cache = True _supports_quantized_cache = False # not all LM bacbones support (e.g. T5) + _keep_in_fp32_modules = ["query_tokens"] # TODO @ArthurZucker I don't know why this is required for FP8 def __init__(self, config: InstructBlipVideoConfig): super().__init__(config) diff --git a/tests/models/llava/test_configuration_llava.py b/tests/models/llava/test_configuration_llava.py index 458743887d..3b28adc1ee 100644 --- a/tests/models/llava/test_configuration_llava.py +++ b/tests/models/llava/test_configuration_llava.py @@ -58,13 +58,13 @@ class LlavaConfigTest(unittest.TestCase): """ Simple test for reloading arbirarily composed subconfigs """ - default_values = LlavaConfig().to_dict() - default_values["vision_config"]["model_type"] = "qwen2_vl" + default_values = LlavaConfig().to_diff_dict() + default_values["vision_config"]["model_type"] = "pixtral" default_values["text_config"]["model_type"] = "opt" - + self.maxDiff = None with tempfile.TemporaryDirectory() as tmp_dir: config = LlavaConfig(**default_values) config.save_pretrained(tmp_dir) reloaded = LlavaConfig.from_pretrained(tmp_dir) - assert config.to_dict() == reloaded.to_dict() + self.assertDictEqual(config.to_dict(), reloaded.to_dict()) diff --git a/tests/tp/test_tp.py b/tests/tensor_parallel/test_tensor_parallel.py similarity index 100% rename from tests/tp/test_tp.py rename to tests/tensor_parallel/test_tensor_parallel.py