Add ep (#39501)
* EP + updates Co-authored-by: Nouamane Tazi <NouamaneTazi@users.noreply.github.com> Co-authored-by: drbh <drbh@users.noreply.github.com> * remove unrelated change * not working yet but let's see where it goes! * update the api a bit * udpate * where I am at for now * fix ep * refactor the API * yups * fix * fixup * clean modeling * just support llama4 for now! * properly avoid * fix * nits * Update src/transformers/models/llama4/modeling_llama4.py * Update src/transformers/integrations/tensor_parallel.py * style * ,,,, * update --------- Co-authored-by: Nouamane Tazi <NouamaneTazi@users.noreply.github.com> Co-authored-by: drbh <drbh@users.noreply.github.com>
This commit is contained in:
33
src/transformers/distributed/__init__.py
Normal file
33
src/transformers/distributed/__init__.py
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from ..utils import _LazyModule
|
||||||
|
|
||||||
|
|
||||||
|
_import_structure = {
|
||||||
|
"configuration_utils": ["DistributedConfig"],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .configuration_utils import (
|
||||||
|
DistributedConfig,
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
import sys
|
||||||
|
|
||||||
|
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
|
||||||
111
src/transformers/distributed/configuration_utils.py
Normal file
111
src/transformers/distributed/configuration_utils.py
Normal file
@@ -0,0 +1,111 @@
|
|||||||
|
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import copy
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, Union
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DistributedConfig:
|
||||||
|
"""
|
||||||
|
Base class for distributed configs
|
||||||
|
"""
|
||||||
|
|
||||||
|
enable_expert_parallel: bool = False
|
||||||
|
# TODO: add tp_plan, pp_plan, device_mesh etc..
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, config_dict, **kwargs):
|
||||||
|
"""
|
||||||
|
Constructs a DistributedConfig instance from a dictionary of parameters.
|
||||||
|
Args:
|
||||||
|
config_dict (Dict[str, Any]): Dictionary containing configuration parameters.
|
||||||
|
**kwargs: Additional keyword arguments to override dictionary values.
|
||||||
|
Returns:
|
||||||
|
DistributedConfig: Instance of DistributedConfig constructed from the dictionary.
|
||||||
|
"""
|
||||||
|
config = cls(**config_dict)
|
||||||
|
to_remove = []
|
||||||
|
for key, value in kwargs.items():
|
||||||
|
if hasattr(config, key):
|
||||||
|
setattr(config, key, value)
|
||||||
|
to_remove.append(key)
|
||||||
|
for key in to_remove:
|
||||||
|
kwargs.pop(key, None)
|
||||||
|
return config
|
||||||
|
|
||||||
|
# Copied from transformers.utils.quantization_config.QuantizationConfigMixin.to_json_file
|
||||||
|
def to_json_file(self, json_file_path: Union[str, os.PathLike]):
|
||||||
|
"""
|
||||||
|
Save this instance to a JSON file.
|
||||||
|
Args:
|
||||||
|
json_file_path (`str` or `os.PathLike`):
|
||||||
|
Path to the JSON file in which this configuration instance's parameters will be saved.
|
||||||
|
use_diff (`bool`, *optional*, defaults to `True`):
|
||||||
|
If set to `True`, only the difference between the config instance and the default
|
||||||
|
`QuantizationConfig()` is serialized to JSON file.
|
||||||
|
"""
|
||||||
|
with open(json_file_path, "w", encoding="utf-8") as writer:
|
||||||
|
config_dict = self.to_dict()
|
||||||
|
json_string = json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
|
||||||
|
|
||||||
|
writer.write(json_string)
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Serializes this instance to a Python dictionary. Returns:
|
||||||
|
`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
|
||||||
|
"""
|
||||||
|
return copy.deepcopy(self.__dict__)
|
||||||
|
|
||||||
|
# Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__iter__
|
||||||
|
def __iter__(self):
|
||||||
|
"""allows `dict(obj)` for situations where obj may be a dict or QuantizationConfigMixin"""
|
||||||
|
for attr, value in copy.deepcopy(self.__dict__).items():
|
||||||
|
yield attr, value
|
||||||
|
|
||||||
|
# Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__repr__
|
||||||
|
def __repr__(self):
|
||||||
|
return f"{self.__class__.__name__} {self.to_json_string()}"
|
||||||
|
|
||||||
|
def to_json_string(self):
|
||||||
|
"""
|
||||||
|
Serializes this instance to a JSON formatted string.
|
||||||
|
Returns:
|
||||||
|
str: JSON formatted string representing the configuration instance.
|
||||||
|
"""
|
||||||
|
return json.dumps(self.__dict__, indent=2) + "\n"
|
||||||
|
|
||||||
|
def update(self, **kwargs):
|
||||||
|
"""
|
||||||
|
Updates attributes of this class instance with attributes from `kwargs` if they match existing attributes,
|
||||||
|
returning all the unused kwargs.
|
||||||
|
Args:
|
||||||
|
kwargs (`Dict[str, Any]`):
|
||||||
|
Dictionary of attributes to tentatively update this class.
|
||||||
|
Returns:
|
||||||
|
`Dict[str, Any]`: Dictionary containing all the key-value pairs that were not used to update the instance.
|
||||||
|
"""
|
||||||
|
to_remove = []
|
||||||
|
for key, value in kwargs.items():
|
||||||
|
if hasattr(self, key):
|
||||||
|
setattr(self, key, value)
|
||||||
|
to_remove.append(key)
|
||||||
|
|
||||||
|
# Remove all the attributes that were updated, without modifying the input dict
|
||||||
|
unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove}
|
||||||
|
return unused_kwargs
|
||||||
@@ -52,6 +52,12 @@ try:
|
|||||||
layer_name="TritonLlamaMLP",
|
layer_name="TritonLlamaMLP",
|
||||||
)
|
)
|
||||||
},
|
},
|
||||||
|
"MegaBlocksMoeMLP": {
|
||||||
|
"cuda": LayerRepository(
|
||||||
|
repo_id="kernels-community/megablocks",
|
||||||
|
layer_name="MegaBlocksMoeMLP",
|
||||||
|
)
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
register_kernel_mapping(_KERNEL_MAPPING)
|
register_kernel_mapping(_KERNEL_MAPPING)
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ import torch
|
|||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
|
from ..distributed import DistributedConfig
|
||||||
from ..utils import is_torch_greater_or_equal, logging
|
from ..utils import is_torch_greater_or_equal, logging
|
||||||
from ..utils.generic import GeneralInterface
|
from ..utils.generic import GeneralInterface
|
||||||
|
|
||||||
@@ -90,7 +91,7 @@ def initialize_tensor_parallelism(tp_plan, tp_size=None):
|
|||||||
device_map = tp_device
|
device_map = tp_device
|
||||||
tp_size = tp_size if tp_size is not None else torch.distributed.get_world_size()
|
tp_size = tp_size if tp_size is not None else torch.distributed.get_world_size()
|
||||||
device_mesh = torch.distributed.init_device_mesh(tp_device.type, (tp_size,))
|
device_mesh = torch.distributed.init_device_mesh(tp_device.type, (tp_size,))
|
||||||
return tp_device, device_map, device_mesh
|
return tp_device, device_map, device_mesh, tp_size
|
||||||
|
|
||||||
|
|
||||||
def _blocks_to_block_sizes(total_size: int, blocks: int | list[int]) -> list[int]:
|
def _blocks_to_block_sizes(total_size: int, blocks: int | list[int]) -> list[int]:
|
||||||
@@ -119,20 +120,22 @@ def _blocks_to_block_sizes(total_size: int, blocks: int | list[int]) -> list[int
|
|||||||
return [single_size] * blocks
|
return [single_size] * blocks
|
||||||
|
|
||||||
|
|
||||||
def _get_parameter_tp_plan(parameter_name: str, tp_plan: dict[str, str]) -> str | None:
|
def _get_parameter_tp_plan(parameter_name: str, tp_plan: dict[str, str], is_weight=True) -> str | None:
|
||||||
"""
|
"""
|
||||||
Get the TP style for a parameter from the TP plan.
|
Get the TP style for a parameter from the TP plan.
|
||||||
|
|
||||||
The TP plan is a dictionary that maps parameter names to TP styles.
|
The TP plan is a dictionary that maps parameter names to TP styles.
|
||||||
The parameter name can be a generic name with wildcards (e.g. "*.weight") or a specific name (e.g. "layer_1.weight").
|
The parameter name can be a generic name with wildcards (e.g. "*.weight") or a specific name (e.g. "layer_1.weight").
|
||||||
|
|
||||||
|
The `is_weight` is important because for weights, we want to support `.weights` and `.bias` cases seamlessly! but
|
||||||
|
not parrent classes for `post_init` calls
|
||||||
"""
|
"""
|
||||||
generic_param_name = re.sub(r"\d+", "*", parameter_name)
|
generic_param_name = re.sub(r"\d+", "*", parameter_name)
|
||||||
if generic_param_name in tp_plan:
|
if generic_param_name in tp_plan:
|
||||||
return tp_plan[generic_param_name]
|
return tp_plan[generic_param_name]
|
||||||
elif "." in generic_param_name and generic_param_name.rsplit(".", 1)[0] in tp_plan:
|
elif "." in generic_param_name and generic_param_name.rsplit(".", 1)[0] in tp_plan and is_weight:
|
||||||
return tp_plan[generic_param_name.rsplit(".", 1)[0]]
|
return tp_plan[generic_param_name.rsplit(".", 1)[0]]
|
||||||
else:
|
return None
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
str_to_torch_dtype = {
|
str_to_torch_dtype = {
|
||||||
@@ -198,8 +201,10 @@ def get_packed_weights(param, empty_param, device_mesh, rank, dim):
|
|||||||
slice_dtype = slice_.get_dtype()
|
slice_dtype = slice_.get_dtype()
|
||||||
# Handle F8_E4M3 dtype by converting to float16 before slicing
|
# Handle F8_E4M3 dtype by converting to float16 before slicing
|
||||||
# Without upcasting, the slicing causes : RuntimeError: "index_cpu" not implemented for 'Float8_e4m3fn'
|
# Without upcasting, the slicing causes : RuntimeError: "index_cpu" not implemented for 'Float8_e4m3fn'
|
||||||
if slice_dtype == "F8_E4M3":
|
casted = False
|
||||||
|
if slice_dtype == "F8_E4M3" or slice_dtype == "F8_E5M2":
|
||||||
slice_ = slice_[...].to(torch.float16)
|
slice_ = slice_[...].to(torch.float16)
|
||||||
|
casted = True
|
||||||
|
|
||||||
if dim == 0:
|
if dim == 0:
|
||||||
tensor = slice_[tensors_slices, ...]
|
tensor = slice_[tensors_slices, ...]
|
||||||
@@ -209,7 +214,11 @@ def get_packed_weights(param, empty_param, device_mesh, rank, dim):
|
|||||||
tensor = slice_[..., tensors_slices]
|
tensor = slice_[..., tensors_slices]
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported dim {dim}, only dim 0, 1 or 2 are supported")
|
raise ValueError(f"Unsupported dim {dim}, only dim 0, 1 or 2 are supported")
|
||||||
return tensor.to(str_to_torch_dtype[slice_dtype])
|
|
||||||
|
if casted:
|
||||||
|
return tensor
|
||||||
|
else:
|
||||||
|
return tensor.to(str_to_torch_dtype[slice_dtype])
|
||||||
|
|
||||||
|
|
||||||
def repack_weights(
|
def repack_weights(
|
||||||
@@ -423,16 +432,27 @@ class GatherParallel(TensorParallelLayer):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
|
def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
|
||||||
|
mod.expert_parallel_group = device_mesh.get_group()
|
||||||
if inputs and isinstance(inputs[0], DTensor):
|
if inputs and isinstance(inputs[0], DTensor):
|
||||||
inputs = inputs[0].to_local()
|
inputs = inputs[0].to_local()
|
||||||
return inputs
|
return inputs
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
|
def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
|
||||||
# this op cannot be async, otherwise it completely breaks the outputs of models
|
if isinstance(outputs, torch.Tensor):
|
||||||
torch.distributed.all_reduce(outputs[0], op=torch.distributed.ReduceOp.SUM, async_op=False)
|
dist.all_reduce(outputs, op=dist.ReduceOp.SUM, async_op=False)
|
||||||
|
else:
|
||||||
|
dist.all_reduce(outputs[0], op=dist.ReduceOp.SUM, async_op=False)
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
def prepare_module_tp(self, module: nn.Module, device_mesh) -> nn.Module:
|
||||||
|
distribute_module(
|
||||||
|
module,
|
||||||
|
device_mesh,
|
||||||
|
partial(self._prepare_input_fn, None, None),
|
||||||
|
partial(self._prepare_output_fn, None, None),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class IsolatedParallel(TensorParallelLayer):
|
class IsolatedParallel(TensorParallelLayer):
|
||||||
"""
|
"""
|
||||||
@@ -453,6 +473,14 @@ class IsolatedParallel(TensorParallelLayer):
|
|||||||
# TODO: figure out dynamo support for instance method and switch this to instance method
|
# TODO: figure out dynamo support for instance method and switch this to instance method
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
|
||||||
|
param = param[...].to(param_casting_dtype)
|
||||||
|
if to_contiguous:
|
||||||
|
param = param.contiguous()
|
||||||
|
param = param / device_mesh.size() # TODO should be optionable
|
||||||
|
# TODO: assumes parent module will allreduce the output afterwards (e.g rowlinear bias is IsolatedParallel and parent module is GatherParallel)
|
||||||
|
return param
|
||||||
|
|
||||||
def prepare_module_tp(self, module: nn.Module, device_mesh) -> nn.Module:
|
def prepare_module_tp(self, module: nn.Module, device_mesh) -> nn.Module:
|
||||||
distribute_module(
|
distribute_module(
|
||||||
module,
|
module,
|
||||||
@@ -773,6 +801,108 @@ class SequenceParallel(TensorParallelLayer):
|
|||||||
return nn.Parameter(parameter, requires_grad=parameter.is_floating_point())
|
return nn.Parameter(parameter, requires_grad=parameter.is_floating_point())
|
||||||
|
|
||||||
|
|
||||||
|
class GroupedGemmParallel(TensorParallelLayer):
|
||||||
|
"""
|
||||||
|
Applies Expert Parallelism to MoE experts by loading the correct experts on each device.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.use_dtensor = False
|
||||||
|
|
||||||
|
def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
|
||||||
|
ep_rank = rank
|
||||||
|
global_num_experts = empty_param.shape[0]
|
||||||
|
if global_num_experts % device_mesh.size() != 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"Global number of experts must be divisible by number of devices: {global_num_experts} % {device_mesh.size()} != 0"
|
||||||
|
)
|
||||||
|
local_num_experts = global_num_experts // device_mesh.size()
|
||||||
|
param = param[ep_rank * local_num_experts : (ep_rank + 1) * local_num_experts].to(param_casting_dtype)
|
||||||
|
if to_contiguous:
|
||||||
|
param = param.contiguous()
|
||||||
|
if "gate_up" in param_type and False:
|
||||||
|
param = torch.cat([param[..., ::2], param[..., 1::2]], dim=-1)
|
||||||
|
return param
|
||||||
|
|
||||||
|
|
||||||
|
class RouterParallel(TensorParallelLayer):
|
||||||
|
"""
|
||||||
|
Allows to reshape the router scores to support running expert parallel.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
self.args = args
|
||||||
|
self.kwargs = kwargs
|
||||||
|
self.use_dtensor = False
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
|
||||||
|
input_tensor = inputs[0]
|
||||||
|
if isinstance(input_tensor, DTensor):
|
||||||
|
raise NotImplementedError("RouterParallel does not support DTensor input for now")
|
||||||
|
return input_tensor
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
|
||||||
|
"""
|
||||||
|
Imagine if you had 4 tokens, top_k = 4, and 128experts.
|
||||||
|
With EP = 8.
|
||||||
|
Imagine router_indices being:
|
||||||
|
[ 52, 42, 119, 67],
|
||||||
|
[102, 89, 61, 40],
|
||||||
|
[ 82, 103, 4, 34],
|
||||||
|
[ 93, 23, 109, 11],
|
||||||
|
|
||||||
|
then you can map which rank should be getting which values
|
||||||
|
|
||||||
|
[3, 2, 7, 4],
|
||||||
|
[6, 5, 3, 2],
|
||||||
|
[5, 6, 0, 2],
|
||||||
|
[5, 1, 6, 0],
|
||||||
|
|
||||||
|
Thus for say rank 0, you fill with 0 the index tensor
|
||||||
|
|
||||||
|
[ 0, 0, 0, 0],
|
||||||
|
[ 0, 0, 0, 0],
|
||||||
|
[ 0, 0, 4, 0],
|
||||||
|
[ 0, 0, 0, 11],
|
||||||
|
|
||||||
|
This works well. For another rank you need to make sure you round to num_local_expert
|
||||||
|
because the next operation will one hot encode the router index vector.
|
||||||
|
|
||||||
|
This allows us to know directly which local expert is hit.
|
||||||
|
Similarly the scores are indexed with something created form
|
||||||
|
router_indices.
|
||||||
|
|
||||||
|
The kinda naive training loop that we use for device_map "auto" uses a similar logic.
|
||||||
|
Here we are just making each rank believe that he is alone, and he computes his part of the hiddenstates.
|
||||||
|
"""
|
||||||
|
ep_rank, ep_size = device_mesh.get_local_rank(), device_mesh.size()
|
||||||
|
num_local_experts = mod.num_experts // ep_size
|
||||||
|
router_scores, router_indices = outputs
|
||||||
|
router_scores = router_scores[:, ep_rank * num_local_experts : (ep_rank + 1) * num_local_experts]
|
||||||
|
router_indices = router_indices.masked_fill((router_indices // num_local_experts) != ep_rank, 0)
|
||||||
|
router_indices = router_indices % num_local_experts
|
||||||
|
return router_scores, router_indices
|
||||||
|
|
||||||
|
def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
|
||||||
|
# TODO: i'd like for this to be the default
|
||||||
|
param = param[...].to(param_casting_dtype)
|
||||||
|
if to_contiguous:
|
||||||
|
param = param.contiguous()
|
||||||
|
return param
|
||||||
|
|
||||||
|
def prepare_module_tp(self, module: nn.Module, device_mesh) -> nn.Module:
|
||||||
|
# TODO: need an abstract Parallel class that is different from TensorParallelLayer
|
||||||
|
distribute_module(
|
||||||
|
module,
|
||||||
|
device_mesh,
|
||||||
|
partial(self._prepare_input_fn, None, None),
|
||||||
|
partial(self._prepare_output_fn, None, None),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ParallelInterface(GeneralInterface):
|
class ParallelInterface(GeneralInterface):
|
||||||
# Class instance object, so that a call to `register` can be reflected into all other files correctly, even if
|
# Class instance object, so that a call to `register` can be reflected into all other files correctly, even if
|
||||||
# a new instance is created (in order to locally override a given entry)
|
# a new instance is created (in order to locally override a given entry)
|
||||||
@@ -789,6 +919,8 @@ class ParallelInterface(GeneralInterface):
|
|||||||
"local_packed_rowwise": PackedRowwiseParallel(use_dtensor=False),
|
"local_packed_rowwise": PackedRowwiseParallel(use_dtensor=False),
|
||||||
"sequence_parallel": SequenceParallel(),
|
"sequence_parallel": SequenceParallel(),
|
||||||
"replicate": ReplicateParallel(),
|
"replicate": ReplicateParallel(),
|
||||||
|
"grouped_gemm": GroupedGemmParallel(),
|
||||||
|
"ep_router": RouterParallel(),
|
||||||
}
|
}
|
||||||
if is_torch_greater_or_equal("2.5") and _torch_distributed_available
|
if is_torch_greater_or_equal("2.5") and _torch_distributed_available
|
||||||
else {}
|
else {}
|
||||||
@@ -841,25 +973,17 @@ def replace_state_dict_local_with_dtensor(
|
|||||||
return state_dict
|
return state_dict
|
||||||
|
|
||||||
|
|
||||||
def add_tensor_parallel_hooks_to_module(model, module, tp_plan, layer_name, current_module_plan, device_mesh):
|
def add_tensor_parallel_hooks_to_module(
|
||||||
"""
|
model, module, tp_plan, layer_name, current_module_plan, device_mesh, parameter_name=None
|
||||||
Add hooks to the module holding the layer. Meaning:
|
):
|
||||||
```
|
r"""
|
||||||
class MyModel(nn.Module):
|
This function is called in `PretrainedModel.post_init()`. It is responsible of adding hooks
|
||||||
def __init__(self):
|
to the modules of the `model`, based on the `PretrainedModel._tp_plan`.
|
||||||
self.layer = nn.Linear(10, 10)
|
|
||||||
```
|
|
||||||
has state_dict like:
|
|
||||||
```
|
|
||||||
{
|
|
||||||
"layer.weight": torch.Tensor,
|
|
||||||
"layer.bias": torch.Tensor
|
|
||||||
}
|
|
||||||
```
|
|
||||||
we add hooks to `MyModel` as well as `layer` to make sure that the tensors are correctly sharded and gathered.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# 1. We add hooks to the layer being loaded:
|
This is the place where we add the `pre_forward` and `post_forwards` hooks. These are defined
|
||||||
|
for each `TensorParallelLayer` as `_prepare_input_fn` and `_prepare_output_fn`.
|
||||||
|
|
||||||
|
"""
|
||||||
if current_module_plan is not None:
|
if current_module_plan is not None:
|
||||||
tp_layer = ALL_PARALLEL_STYLES[current_module_plan]
|
tp_layer = ALL_PARALLEL_STYLES[current_module_plan]
|
||||||
try:
|
try:
|
||||||
@@ -868,26 +992,19 @@ def add_tensor_parallel_hooks_to_module(model, module, tp_plan, layer_name, curr
|
|||||||
print(
|
print(
|
||||||
f"Trying to prepare {layer_name}, but it's not supported. Corresponding module: {module} Fix it's TP plan: {e}"
|
f"Trying to prepare {layer_name}, but it's not supported. Corresponding module: {module} Fix it's TP plan: {e}"
|
||||||
)
|
)
|
||||||
|
|
||||||
module._hf_tp_plan = current_module_plan
|
module._hf_tp_plan = current_module_plan
|
||||||
module.__repr__ = lambda: f"{module.__repr__()}\nTP Plan: {current_module_plan}"
|
module.__repr__ = lambda: f"{module.__repr__()}\nTP Plan: {current_module_plan}"
|
||||||
|
|
||||||
# 2. We add hooks to the parent module if needed
|
|
||||||
if "." in layer_name:
|
|
||||||
parent_layer_name = layer_name.rsplit(".", 1)[0]
|
|
||||||
generic_name = re.sub(r"\d+", "*", parent_layer_name)
|
|
||||||
# The module itself needs hooks
|
|
||||||
if module_plan := tp_plan.get(generic_name, False):
|
|
||||||
tp_layer = ALL_PARALLEL_STYLES[module_plan]
|
|
||||||
module_to_tp_ = model.get_submodule(parent_layer_name)
|
|
||||||
tp_layer.prepare_module_tp(module_to_tp_, device_mesh)
|
|
||||||
module_to_tp_._hf_tp_plan = current_module_plan
|
|
||||||
module_to_tp_.__repr__ = lambda: f"{module_to_tp_.__repr__()}\nTP Plan: {current_module_plan}"
|
|
||||||
|
|
||||||
|
|
||||||
def shard_and_distribute_module(
|
def shard_and_distribute_module(
|
||||||
model, param, empty_param, parameter_name, param_casting_dtype, is_contiguous, rank, device_mesh
|
model, param, empty_param, parameter_name, param_casting_dtype, is_contiguous, rank, device_mesh
|
||||||
):
|
): # TODO: rename to shard_and_distribute_param
|
||||||
r"""
|
r"""
|
||||||
|
This function is called in `from_pretrained` when loading a model's checkpoints.
|
||||||
|
It receives the pointer to the parameter (or the parameter itself) and takes care of "sharding".
|
||||||
|
All process run this function, so they just load the partition of the tensor that they require.
|
||||||
|
|
||||||
Main uses cases:
|
Main uses cases:
|
||||||
- column / rowise parallelism, you just shard all the weights of the layer (weight and bias)
|
- column / rowise parallelism, you just shard all the weights of the layer (weight and bias)
|
||||||
- packed layers: you slice the weights, then shard like above
|
- packed layers: you slice the weights, then shard like above
|
||||||
@@ -898,39 +1015,33 @@ def shard_and_distribute_module(
|
|||||||
"""
|
"""
|
||||||
param_name, param_type = parameter_name.rsplit(".", 1) if "." in parameter_name else parameter_name
|
param_name, param_type = parameter_name.rsplit(".", 1) if "." in parameter_name else parameter_name
|
||||||
tp_plan = model._tp_plan
|
tp_plan = model._tp_plan
|
||||||
module_to_tp = model.get_submodule(param_name)
|
module_to_tp = model.get_submodule(param_name) # TODO: can i loop over modules?
|
||||||
rank = int(rank)
|
rank = int(rank)
|
||||||
|
current_shard_plan = _get_parameter_tp_plan(parameter_name, tp_plan)
|
||||||
|
|
||||||
current_module_plan = _get_parameter_tp_plan(parameter_name, tp_plan)
|
if dist.get_rank() == 0:
|
||||||
|
if current_shard_plan is None:
|
||||||
|
logger.info(f"Tensor sharding plan for {param_name} not found, using default 'replicate' plan.")
|
||||||
|
else:
|
||||||
|
logger.info(f"Tensor sharding plan for {param_name}: {current_shard_plan}")
|
||||||
|
|
||||||
if current_module_plan is None:
|
if current_shard_plan is not None:
|
||||||
current_module_plan = "replicate"
|
try:
|
||||||
if dist.get_rank() == 0:
|
tp_layer = ALL_PARALLEL_STYLES[current_shard_plan]
|
||||||
logger.info(f"Tensor parallel plan for {param_name} not found, using default 'replicate' plan.")
|
param = tp_layer.partition_tensor(
|
||||||
|
param, empty_param, param_type, param_casting_dtype, is_contiguous, rank, device_mesh
|
||||||
|
)
|
||||||
|
except NotImplementedError as e:
|
||||||
|
print(
|
||||||
|
f"Trying to prepare {parameter_name}, but it's not supported. Corresponding module: {module_to_tp} Fix it's TP plan, current layer: {tp_layer} : {e}"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
if dist.get_rank() == 0:
|
param = param[:].to(param_casting_dtype)
|
||||||
logger.info(f"Tensor parallel plan for {param_name}: {current_module_plan}")
|
|
||||||
|
|
||||||
# Add hooks to the module if not done yet
|
|
||||||
# add_tensor_parallel_hooks_to_module(model, module_to_tp, tp_plan, param_name, current_module_plan, device_mesh)
|
|
||||||
if not getattr(module_to_tp, "_is_hooked", False):
|
|
||||||
add_tensor_parallel_hooks_to_module(model, module_to_tp, tp_plan, param_name, current_module_plan, device_mesh)
|
|
||||||
module_to_tp._is_hooked = True
|
|
||||||
|
|
||||||
try:
|
|
||||||
tp_layer = ALL_PARALLEL_STYLES[current_module_plan]
|
|
||||||
param = tp_layer.partition_tensor(
|
|
||||||
param, empty_param, param_type, param_casting_dtype, is_contiguous, rank, device_mesh
|
|
||||||
)
|
|
||||||
except NotImplementedError as e:
|
|
||||||
print(
|
|
||||||
f"Trying to prepare {parameter_name}, but it's not supported. Corresponding module: {module_to_tp} Fix it's TP plan, current layer: {tp_layer} : {e}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# SUPER IMPORTANT we have to use setattr
|
# SUPER IMPORTANT we have to use setattr
|
||||||
# otherwise loading is crazy slow
|
# otherwise loading is crazy slow
|
||||||
if not isinstance(param, torch.nn.Parameter):
|
if not isinstance(param, torch.nn.Parameter):
|
||||||
param = torch.nn.Parameter(param, requires_grad=param.is_floating_point())
|
param = torch.nn.Parameter(param, requires_grad=empty_param.is_floating_point())
|
||||||
setattr(module_to_tp, param_type, param)
|
setattr(module_to_tp, param_type, param)
|
||||||
# module_to_tp.load_state_dict({param_type: param}, strict=False, assign=True)
|
# module_to_tp.load_state_dict({param_type: param}, strict=False, assign=True)
|
||||||
return param
|
return param
|
||||||
@@ -965,3 +1076,43 @@ def verify_tp_plan(expected_keys: list[str], tp_plan: dict[str, str] | None):
|
|||||||
logger.warning(f"The following TP rules were not applied on any of the layers: {unused_rules}")
|
logger.warning(f"The following TP rules were not applied on any of the layers: {unused_rules}")
|
||||||
if len(unsharded_layers) > 0:
|
if len(unsharded_layers) > 0:
|
||||||
logger.warning(f"The following layers were not sharded: {', '.join(unsharded_layers)}")
|
logger.warning(f"The following layers were not sharded: {', '.join(unsharded_layers)}")
|
||||||
|
|
||||||
|
|
||||||
|
def distribute_model(model, distributed_config, device_mesh, tp_size):
|
||||||
|
_plan = "_tp_plan"
|
||||||
|
model._tp_plan = getattr(model.config, "base_model_tp_plan").copy()
|
||||||
|
if distributed_config is not None:
|
||||||
|
distributed_config = DistributedConfig.from_config(distributed_config)
|
||||||
|
if distributed_config.enable_expert_parallel:
|
||||||
|
_plan = "_ep_plan"
|
||||||
|
model._tp_plan = getattr(model.config, "base_model_ep_plan", model._tp_plan).copy()
|
||||||
|
|
||||||
|
# now fetch my childrens
|
||||||
|
for name, module in model.named_children():
|
||||||
|
if plan := getattr(module, _plan, getattr(module, "tp_plan", None)):
|
||||||
|
model._tp_plan.update({f"{name}.{k}": v for k, v in plan.copy().items()})
|
||||||
|
if hasattr(module, "config"):
|
||||||
|
plan = getattr(module.config, f"base_model{_plan}", {})
|
||||||
|
if plan == {}:
|
||||||
|
plan = getattr(module.config, "base_model_tp_plan", {})
|
||||||
|
model._tp_plan.update({f"{name}.{k}": v for k, v in plan.copy().items()})
|
||||||
|
|
||||||
|
if model._tp_plan is not None and is_torch_greater_or_equal("2.5") and _torch_distributed_available:
|
||||||
|
for v in model._tp_plan.values():
|
||||||
|
if v not in ALL_PARALLEL_STYLES:
|
||||||
|
raise ValueError(f"Unsupported tensor parallel style {v}. Supported styles are {ALL_PARALLEL_STYLES}")
|
||||||
|
for name, module in model.named_modules():
|
||||||
|
if not getattr(module, "_is_hooked", False):
|
||||||
|
from transformers.integrations.tensor_parallel import add_tensor_parallel_hooks_to_module
|
||||||
|
|
||||||
|
plan = _get_parameter_tp_plan(parameter_name=name, tp_plan=model._tp_plan, is_weight=False)
|
||||||
|
add_tensor_parallel_hooks_to_module(
|
||||||
|
model=model,
|
||||||
|
module=module,
|
||||||
|
tp_plan=model._tp_plan,
|
||||||
|
layer_name="",
|
||||||
|
current_module_plan=plan,
|
||||||
|
device_mesh=device_mesh,
|
||||||
|
)
|
||||||
|
module._is_hooked = True
|
||||||
|
return model
|
||||||
|
|||||||
@@ -63,8 +63,8 @@ from .integrations.flex_attention import flex_attention_forward
|
|||||||
from .integrations.sdpa_attention import sdpa_attention_forward
|
from .integrations.sdpa_attention import sdpa_attention_forward
|
||||||
from .integrations.sdpa_paged import sdpa_attention_paged_forward
|
from .integrations.sdpa_paged import sdpa_attention_paged_forward
|
||||||
from .integrations.tensor_parallel import (
|
from .integrations.tensor_parallel import (
|
||||||
ALL_PARALLEL_STYLES,
|
|
||||||
_get_parameter_tp_plan,
|
_get_parameter_tp_plan,
|
||||||
|
distribute_model,
|
||||||
initialize_tensor_parallelism,
|
initialize_tensor_parallelism,
|
||||||
repack_weights,
|
repack_weights,
|
||||||
replace_state_dict_local_with_dtensor,
|
replace_state_dict_local_with_dtensor,
|
||||||
@@ -2218,6 +2218,9 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
|||||||
"""
|
"""
|
||||||
A method executed at the end of each Transformer model initialization, to execute code that needs the model's
|
A method executed at the end of each Transformer model initialization, to execute code that needs the model's
|
||||||
modules properly initialized (such as weight initialization).
|
modules properly initialized (such as weight initialization).
|
||||||
|
|
||||||
|
This is also used when the user is running distributed code. We add hooks to the modules here, according to
|
||||||
|
the model's tp_plan!
|
||||||
"""
|
"""
|
||||||
self.init_weights()
|
self.init_weights()
|
||||||
self._backward_compatibility_gradient_checkpointing()
|
self._backward_compatibility_gradient_checkpointing()
|
||||||
@@ -2250,17 +2253,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
|||||||
|
|
||||||
# If current model is a base model, attach `base_model_tp_plan` and `base_model_pp_plan` from config
|
# If current model is a base model, attach `base_model_tp_plan` and `base_model_pp_plan` from config
|
||||||
self._pp_plan = self.config.base_model_pp_plan.copy() if self.config.base_model_pp_plan is not None else None
|
self._pp_plan = self.config.base_model_pp_plan.copy() if self.config.base_model_pp_plan is not None else None
|
||||||
self._tp_plan = self.config.base_model_tp_plan.copy() if self.config.base_model_tp_plan is not None else {}
|
|
||||||
for name, module in self.named_children():
|
|
||||||
if plan := getattr(module, "_tp_plan", None):
|
|
||||||
self._tp_plan.update({f"{name}.{k}": v for k, v in plan.copy().items()})
|
|
||||||
|
|
||||||
if self._tp_plan is not None and is_torch_greater_or_equal("2.5") and _torch_distributed_available:
|
|
||||||
for v in self._tp_plan.values():
|
|
||||||
if v not in ALL_PARALLEL_STYLES:
|
|
||||||
raise ValueError(
|
|
||||||
f"Unsupported tensor parallel style {v}. Supported styles are {ALL_PARALLEL_STYLES}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def dequantize(self):
|
def dequantize(self):
|
||||||
"""
|
"""
|
||||||
@@ -4568,6 +4560,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
|||||||
load_in_8bit = kwargs.pop("load_in_8bit", False)
|
load_in_8bit = kwargs.pop("load_in_8bit", False)
|
||||||
load_in_4bit = kwargs.pop("load_in_4bit", False)
|
load_in_4bit = kwargs.pop("load_in_4bit", False)
|
||||||
quantization_config = kwargs.pop("quantization_config", None)
|
quantization_config = kwargs.pop("quantization_config", None)
|
||||||
|
distributed_config = kwargs.pop("distributed_config", None)
|
||||||
subfolder = kwargs.pop("subfolder", "")
|
subfolder = kwargs.pop("subfolder", "")
|
||||||
commit_hash = kwargs.pop("_commit_hash", None)
|
commit_hash = kwargs.pop("_commit_hash", None)
|
||||||
variant = kwargs.pop("variant", None)
|
variant = kwargs.pop("variant", None)
|
||||||
@@ -4588,6 +4581,9 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
|||||||
):
|
):
|
||||||
key_mapping = cls._checkpoint_conversion_mapping
|
key_mapping = cls._checkpoint_conversion_mapping
|
||||||
|
|
||||||
|
if distributed_config is not None:
|
||||||
|
tp_plan = "auto"
|
||||||
|
|
||||||
# Not used anymore -- remove them from the kwargs
|
# Not used anymore -- remove them from the kwargs
|
||||||
_ = kwargs.pop("resume_download", None)
|
_ = kwargs.pop("resume_download", None)
|
||||||
_ = kwargs.pop("mirror", None)
|
_ = kwargs.pop("mirror", None)
|
||||||
@@ -4619,16 +4615,12 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
|||||||
# `device_map` pointing to the correct device
|
# `device_map` pointing to the correct device
|
||||||
if tp_plan is not None:
|
if tp_plan is not None:
|
||||||
if device_mesh is None:
|
if device_mesh is None:
|
||||||
tp_plan, device_map, device_mesh = initialize_tensor_parallelism(tp_plan, tp_size=None)
|
tp_plan, device_map, device_mesh, tp_size = initialize_tensor_parallelism(tp_plan, tp_size=tp_size)
|
||||||
else:
|
else:
|
||||||
if "tp" not in device_mesh.mesh_dim_names:
|
# TODO: make device_mesh support multiple dimensions
|
||||||
raise ValueError(
|
if device_mesh.ndim > 1:
|
||||||
"When using `tp_plan`, the `device_mesh` must contain a 'tp' dimension. "
|
raise ValueError("device_mesh must be 1 dimensional and will be used for TP")
|
||||||
"Please provide a valid `device_mesh`."
|
device_map = torch.device(device_mesh.device_type, int(os.environ["LOCAL_RANK"]))
|
||||||
)
|
|
||||||
device_mesh = device_mesh["tp"]
|
|
||||||
tp_size = device_mesh["tp"].size()
|
|
||||||
device_map = torch.device(f"{device_mesh.device_type}:{int(os.environ['LOCAL_RANK'])}")
|
|
||||||
|
|
||||||
if tp_size is None:
|
if tp_size is None:
|
||||||
tp_size = torch.distributed.get_world_size()
|
tp_size = torch.distributed.get_world_size()
|
||||||
@@ -4928,23 +4920,18 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
|||||||
)
|
)
|
||||||
|
|
||||||
config.name_or_path = pretrained_model_name_or_path
|
config.name_or_path = pretrained_model_name_or_path
|
||||||
|
|
||||||
# Instantiate model.
|
|
||||||
model_init_context = cls.get_init_context(is_quantized, _is_ds_init_called)
|
model_init_context = cls.get_init_context(is_quantized, _is_ds_init_called)
|
||||||
|
|
||||||
config = copy.deepcopy(config) # We do not want to modify the config inplace in from_pretrained.
|
config = copy.deepcopy(config) # We do not want to modify the config inplace in from_pretrained.
|
||||||
with ContextManagers(model_init_context):
|
with ContextManagers(model_init_context):
|
||||||
# Let's make sure we don't run the init function of buffer modules
|
# Let's make sure we don't run the init function of buffer modules
|
||||||
model = cls(config, *model_args, **model_kwargs)
|
model = cls(config, *model_args, **model_kwargs)
|
||||||
|
|
||||||
|
if _torch_distributed_available and device_mesh is not None:
|
||||||
|
model = distribute_model(model, distributed_config, device_mesh, tp_size)
|
||||||
|
|
||||||
# Make sure to tie the weights correctly
|
# Make sure to tie the weights correctly
|
||||||
model.tie_weights()
|
model.tie_weights()
|
||||||
|
|
||||||
# Last check for tp
|
|
||||||
if device_mesh is not None and not model.supports_tp_plan:
|
|
||||||
if config.base_model_tp_plan is None and config.get_text_config().base_model_tp_plan is None:
|
|
||||||
raise NotImplementedError("This model does not have a tensor parallel plan.")
|
|
||||||
|
|
||||||
# make sure we use the model's config since the __init__ call might have copied it
|
# make sure we use the model's config since the __init__ call might have copied it
|
||||||
config = model.config
|
config = model.config
|
||||||
|
|
||||||
@@ -5025,11 +5012,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
|||||||
key_mapping=key_mapping,
|
key_mapping=key_mapping,
|
||||||
weights_only=weights_only,
|
weights_only=weights_only,
|
||||||
)
|
)
|
||||||
|
|
||||||
# record tp degree the model sharded to
|
|
||||||
model._tp_size = tp_size
|
|
||||||
model._device_mesh = device_mesh
|
|
||||||
|
|
||||||
# make sure token embedding weights are still tied if needed
|
# make sure token embedding weights are still tied if needed
|
||||||
model.tie_weights()
|
model.tie_weights()
|
||||||
|
|
||||||
|
|||||||
@@ -265,6 +265,19 @@ class Llama4TextConfig(PretrainedConfig):
|
|||||||
"layers.*.feed_forward.down_proj": "local_rowwise",
|
"layers.*.feed_forward.down_proj": "local_rowwise",
|
||||||
"layers.*.feed_forward": "gather",
|
"layers.*.feed_forward": "gather",
|
||||||
}
|
}
|
||||||
|
base_model_ep_plan = {
|
||||||
|
"layers.*.self_attn.q_proj": "colwise",
|
||||||
|
"layers.*.self_attn.k_proj": "colwise",
|
||||||
|
"layers.*.self_attn.v_proj": "colwise",
|
||||||
|
"layers.*.self_attn.o_proj": "rowwise",
|
||||||
|
"layers.*.feed_forward.experts.gate_up_proj": "grouped_gemm", # row because not linear
|
||||||
|
"layers.*.feed_forward.experts.down_proj": "grouped_gemm", # col because not linear
|
||||||
|
"layers.*.feed_forward.experts": "gather", # all reduce
|
||||||
|
"layers.*.feed_forward.gate_proj": "local_colwise",
|
||||||
|
"layers.*.feed_forward.up_proj": "local_colwise",
|
||||||
|
"layers.*.feed_forward.down_proj": "local_rowwise",
|
||||||
|
"layers.*.feed_forward.router": "ep_router",
|
||||||
|
}
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ from transformers.models.llama4.configuration_llama4 import Llama4VisionConfig
|
|||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
from ...cache_utils import Cache, DynamicCache
|
from ...cache_utils import Cache, DynamicCache
|
||||||
from ...generation import GenerationMixin
|
from ...generation import GenerationMixin
|
||||||
from ...integrations.hub_kernels import use_kernel_forward_from_hub
|
from ...integrations import use_kernel_forward_from_hub
|
||||||
from ...masking_utils import create_causal_mask, create_chunked_causal_mask
|
from ...masking_utils import create_causal_mask, create_chunked_causal_mask
|
||||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||||
from ...modeling_layers import GradientCheckpointingLayer
|
from ...modeling_layers import GradientCheckpointingLayer
|
||||||
@@ -35,6 +35,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
|||||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||||
from ...processing_utils import Unpack
|
from ...processing_utils import Unpack
|
||||||
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
|
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
|
||||||
|
from ...utils.generic import check_model_inputs
|
||||||
from .configuration_llama4 import Llama4Config, Llama4TextConfig
|
from .configuration_llama4 import Llama4Config, Llama4TextConfig
|
||||||
|
|
||||||
|
|
||||||
@@ -65,7 +66,7 @@ class Llama4TextExperts(nn.Module):
|
|||||||
Returns:
|
Returns:
|
||||||
torch.Tensor
|
torch.Tensor
|
||||||
"""
|
"""
|
||||||
hidden_states = hidden_states.view(self.num_experts, -1, self.hidden_size)
|
hidden_states = hidden_states.view(self.gate_up_proj.shape[0], -1, self.hidden_size)
|
||||||
gate_up = torch.bmm(hidden_states, self.gate_up_proj)
|
gate_up = torch.bmm(hidden_states, self.gate_up_proj)
|
||||||
gate, up = gate_up.chunk(2, dim=-1) # not supported for DTensors
|
gate, up = gate_up.chunk(2, dim=-1) # not supported for DTensors
|
||||||
next_states = torch.bmm((up * self.act_fn(gate)), self.down_proj)
|
next_states = torch.bmm((up * self.act_fn(gate)), self.down_proj)
|
||||||
@@ -127,6 +128,20 @@ class Llama4TextRMSNorm(nn.Module):
|
|||||||
return f"{tuple(self.weight.shape)}, eps={self.eps}"
|
return f"{tuple(self.weight.shape)}, eps={self.eps}"
|
||||||
|
|
||||||
|
|
||||||
|
class Llama4Router(nn.Linear):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__(config.hidden_size, config.num_local_experts, bias=False)
|
||||||
|
self.num_experts = config.num_local_experts
|
||||||
|
self.top_k = config.num_experts_per_tok
|
||||||
|
|
||||||
|
def forward(self, hidden_states):
|
||||||
|
router_logits = super().forward(hidden_states)
|
||||||
|
router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=1)
|
||||||
|
router_scores = torch.full_like(router_logits, float("-inf")).scatter_(1, router_indices, router_top_value)
|
||||||
|
router_scores = torch.nn.functional.sigmoid(router_scores.float()).to(router_scores.dtype)
|
||||||
|
return router_scores, router_logits
|
||||||
|
|
||||||
|
|
||||||
@use_kernel_forward_from_hub("Llama4TextMoe")
|
@use_kernel_forward_from_hub("Llama4TextMoe")
|
||||||
class Llama4TextMoe(nn.Module):
|
class Llama4TextMoe(nn.Module):
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
@@ -135,28 +150,18 @@ class Llama4TextMoe(nn.Module):
|
|||||||
self.hidden_dim = config.hidden_size
|
self.hidden_dim = config.hidden_size
|
||||||
self.num_experts = config.num_local_experts
|
self.num_experts = config.num_local_experts
|
||||||
self.experts = Llama4TextExperts(config)
|
self.experts = Llama4TextExperts(config)
|
||||||
self.router = nn.Linear(config.hidden_size, config.num_local_experts, bias=False)
|
self.router = Llama4Router(config)
|
||||||
self.shared_expert = Llama4TextMLP(config)
|
self.shared_expert = Llama4TextMLP(config)
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states):
|
||||||
hidden_states = hidden_states.reshape(-1, self.hidden_dim)
|
hidden_states = hidden_states.reshape(-1, self.hidden_dim)
|
||||||
router_logits = self.router(hidden_states)
|
router_scores, router_logits = self.router(hidden_states)
|
||||||
|
routed_in = hidden_states.repeat(router_scores.shape[1], 1)
|
||||||
router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=1)
|
|
||||||
|
|
||||||
router_scores = (
|
|
||||||
torch.full_like(router_logits, float("-inf")).scatter_(1, router_indices, router_top_value).transpose(0, 1)
|
|
||||||
)
|
|
||||||
router_scores = torch.sigmoid(router_scores.float()).to(hidden_states.dtype)
|
|
||||||
|
|
||||||
routed_in = hidden_states.repeat(self.num_experts, 1)
|
|
||||||
routed_in = routed_in * router_scores.reshape(-1, 1)
|
routed_in = routed_in * router_scores.reshape(-1, 1)
|
||||||
routed_out = self.experts(routed_in)
|
routed_out = self.experts(routed_in)
|
||||||
|
|
||||||
out = self.shared_expert(hidden_states)
|
out = self.shared_expert(hidden_states)
|
||||||
out.add_(routed_out.reshape(self.num_experts, -1, self.hidden_dim).sum(dim=0))
|
out.add_(routed_out.reshape(router_scores.shape[1], -1, routed_out.shape[-1]).sum(dim=0))
|
||||||
|
return out, router_logits
|
||||||
return out, router_scores
|
|
||||||
|
|
||||||
|
|
||||||
class Llama4TextRotaryEmbedding(nn.Module):
|
class Llama4TextRotaryEmbedding(nn.Module):
|
||||||
@@ -383,8 +388,6 @@ class Llama4TextDecoderLayer(GradientCheckpointingLayer):
|
|||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
past_key_value: Optional[tuple[torch.Tensor]] = None,
|
past_key_value: Optional[tuple[torch.Tensor]] = None,
|
||||||
output_attentions: Optional[bool] = False,
|
|
||||||
output_router_logits: Optional[bool] = False,
|
|
||||||
use_cache: Optional[bool] = False,
|
use_cache: Optional[bool] = False,
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
|
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
|
||||||
@@ -395,12 +398,11 @@ class Llama4TextDecoderLayer(GradientCheckpointingLayer):
|
|||||||
hidden_states = self.input_layernorm(hidden_states)
|
hidden_states = self.input_layernorm(hidden_states)
|
||||||
|
|
||||||
# Self Attention
|
# Self Attention
|
||||||
attention_states, self_attn_weights = self.self_attn(
|
attention_states, _ = self.self_attn(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
position_embeddings=position_embeddings,
|
position_embeddings=position_embeddings,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
past_key_value=past_key_value,
|
past_key_value=past_key_value,
|
||||||
output_attentions=output_attentions,
|
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
cache_position=cache_position,
|
cache_position=cache_position,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
@@ -409,23 +411,12 @@ class Llama4TextDecoderLayer(GradientCheckpointingLayer):
|
|||||||
|
|
||||||
# Fully Connected
|
# Fully Connected
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
|
|
||||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||||
hidden_states = self.feed_forward(hidden_states)
|
hidden_states = self.feed_forward(hidden_states)
|
||||||
if self.is_moe_layer:
|
if self.is_moe_layer:
|
||||||
hidden_states, router_logits = hidden_states
|
hidden_states, _ = hidden_states
|
||||||
else:
|
|
||||||
router_logits = None
|
|
||||||
hidden_states = residual + hidden_states.view(residual.shape)
|
hidden_states = residual + hidden_states.view(residual.shape)
|
||||||
outputs = (hidden_states,)
|
return hidden_states
|
||||||
|
|
||||||
if output_attentions:
|
|
||||||
outputs += (self_attn_weights,)
|
|
||||||
|
|
||||||
if output_router_logits:
|
|
||||||
outputs += (router_logits,)
|
|
||||||
|
|
||||||
return outputs
|
|
||||||
|
|
||||||
|
|
||||||
@auto_docstring
|
@auto_docstring
|
||||||
@@ -472,6 +463,11 @@ class Llama4TextModel(Llama4PreTrainedModel):
|
|||||||
_no_split_modules = ["Llama4TextDecoderLayer"]
|
_no_split_modules = ["Llama4TextDecoderLayer"]
|
||||||
base_model_prefix = "model"
|
base_model_prefix = "model"
|
||||||
config: Llama4TextConfig
|
config: Llama4TextConfig
|
||||||
|
_can_record_outputs = {
|
||||||
|
"attentions": Llama4TextAttention,
|
||||||
|
"hidden_states": Llama4TextDecoderLayer,
|
||||||
|
"router_logits": Llama4TextMoe,
|
||||||
|
}
|
||||||
|
|
||||||
def __init__(self, config: Llama4TextConfig):
|
def __init__(self, config: Llama4TextConfig):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
@@ -489,7 +485,7 @@ class Llama4TextModel(Llama4PreTrainedModel):
|
|||||||
# Initialize weights and apply final processing
|
# Initialize weights and apply final processing
|
||||||
self.post_init()
|
self.post_init()
|
||||||
|
|
||||||
@can_return_tuple
|
@check_model_inputs
|
||||||
@auto_docstring
|
@auto_docstring
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -499,28 +495,12 @@ class Llama4TextModel(Llama4PreTrainedModel):
|
|||||||
past_key_values: Optional[Cache] = None,
|
past_key_values: Optional[Cache] = None,
|
||||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
use_cache: Optional[bool] = None,
|
use_cache: Optional[bool] = None,
|
||||||
output_attentions: Optional[bool] = None,
|
|
||||||
output_hidden_states: Optional[bool] = None,
|
|
||||||
return_dict: Optional[bool] = None,
|
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
**kwargs: Unpack[TransformersKwargs],
|
**kwargs: Unpack[TransformersKwargs],
|
||||||
) -> Union[tuple, BaseModelOutputWithPast]:
|
) -> Union[tuple, BaseModelOutputWithPast]:
|
||||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
|
||||||
output_hidden_states = (
|
|
||||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
|
||||||
)
|
|
||||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
|
||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
|
||||||
|
|
||||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||||
|
|
||||||
if self.gradient_checkpointing and self.training and use_cache:
|
|
||||||
logger.warning_once(
|
|
||||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
|
|
||||||
)
|
|
||||||
use_cache = False
|
|
||||||
|
|
||||||
if inputs_embeds is None:
|
if inputs_embeds is None:
|
||||||
inputs_embeds = self.embed_tokens(input_ids.to(self.embed_tokens.weight.device))
|
inputs_embeds = self.embed_tokens(input_ids.to(self.embed_tokens.weight.device))
|
||||||
|
|
||||||
@@ -558,42 +538,22 @@ class Llama4TextModel(Llama4PreTrainedModel):
|
|||||||
# create position embeddings to be shared across the decoder layers
|
# create position embeddings to be shared across the decoder layers
|
||||||
freq_cis = self.rotary_emb(hidden_states, position_ids)
|
freq_cis = self.rotary_emb(hidden_states, position_ids)
|
||||||
|
|
||||||
# decoder layers
|
|
||||||
all_hidden_states = () if output_hidden_states else None
|
|
||||||
all_self_attns = () if output_attentions else None
|
|
||||||
|
|
||||||
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
|
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
|
||||||
if output_hidden_states:
|
hidden_states = decoder_layer(
|
||||||
all_hidden_states += (hidden_states,)
|
|
||||||
|
|
||||||
layer_outputs = decoder_layer(
|
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask=causal_mask_mapping[decoder_layer.attention_type],
|
attention_mask=causal_mask_mapping[decoder_layer.attention_type],
|
||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
past_key_value=past_key_values,
|
past_key_value=past_key_values,
|
||||||
output_attentions=output_attentions,
|
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
cache_position=cache_position,
|
cache_position=cache_position,
|
||||||
position_embeddings=freq_cis,
|
position_embeddings=freq_cis,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states = layer_outputs[0]
|
|
||||||
|
|
||||||
if output_attentions:
|
|
||||||
all_self_attns += (layer_outputs[1],)
|
|
||||||
|
|
||||||
hidden_states = self.norm(hidden_states)
|
hidden_states = self.norm(hidden_states)
|
||||||
|
|
||||||
# add hidden states from the last decoder layer
|
|
||||||
if output_hidden_states:
|
|
||||||
all_hidden_states += (hidden_states,)
|
|
||||||
|
|
||||||
return BaseModelOutputWithPast(
|
return BaseModelOutputWithPast(
|
||||||
last_hidden_state=hidden_states,
|
last_hidden_state=hidden_states,
|
||||||
past_key_values=past_key_values if use_cache else None,
|
past_key_values=past_key_values if use_cache else None,
|
||||||
hidden_states=all_hidden_states,
|
|
||||||
attentions=all_self_attns,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -630,9 +590,6 @@ class Llama4ForCausalLM(Llama4PreTrainedModel, GenerationMixin):
|
|||||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
labels: Optional[torch.LongTensor] = None,
|
labels: Optional[torch.LongTensor] = None,
|
||||||
use_cache: Optional[bool] = None,
|
use_cache: Optional[bool] = None,
|
||||||
output_attentions: Optional[bool] = None,
|
|
||||||
output_hidden_states: Optional[bool] = None,
|
|
||||||
return_dict: Optional[bool] = None,
|
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||||
**kwargs: Unpack[TransformersKwargs],
|
**kwargs: Unpack[TransformersKwargs],
|
||||||
@@ -659,13 +616,6 @@ class Llama4ForCausalLM(Llama4PreTrainedModel, GenerationMixin):
|
|||||||
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||||
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
||||||
```"""
|
```"""
|
||||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
|
||||||
output_hidden_states = (
|
|
||||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
|
||||||
)
|
|
||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
|
||||||
|
|
||||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
|
||||||
outputs = self.model(
|
outputs = self.model(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
@@ -673,9 +623,6 @@ class Llama4ForCausalLM(Llama4PreTrainedModel, GenerationMixin):
|
|||||||
past_key_values=past_key_values,
|
past_key_values=past_key_values,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
output_attentions=output_attentions,
|
|
||||||
output_hidden_states=output_hidden_states,
|
|
||||||
return_dict=True,
|
|
||||||
cache_position=cache_position,
|
cache_position=cache_position,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -2634,7 +2634,14 @@ class Trainer:
|
|||||||
|
|
||||||
self.control = self.callback_handler.on_pre_optimizer_step(args, self.state, self.control)
|
self.control = self.callback_handler.on_pre_optimizer_step(args, self.state, self.control)
|
||||||
|
|
||||||
self.optimizer.step()
|
context = contextlib.nullcontext
|
||||||
|
if self.is_tp_enabled:
|
||||||
|
from torch.distributed._tensor.experimental import implicit_replication
|
||||||
|
|
||||||
|
context = implicit_replication
|
||||||
|
|
||||||
|
with context():
|
||||||
|
self.optimizer.step()
|
||||||
|
|
||||||
self.control = self.callback_handler.on_optimizer_step(args, self.state, self.control)
|
self.control = self.callback_handler.on_optimizer_step(args, self.state, self.control)
|
||||||
|
|
||||||
|
|||||||
@@ -109,7 +109,7 @@ class TestTensorParallel(TestCasePlus):
|
|||||||
|
|
||||||
assert has_dtensor == 1, "TP model must has DTensor"
|
assert has_dtensor == 1, "TP model must has DTensor"
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
tokenizer = AutoTokenizer.from_pretrained(model_id, legacy=False)
|
||||||
prompt = "Can I help"
|
prompt = "Can I help"
|
||||||
|
|
||||||
inputs = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
|
inputs = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
|
||||||
|
|||||||
Reference in New Issue
Block a user