* EP + updates

Co-authored-by: Nouamane Tazi <NouamaneTazi@users.noreply.github.com>
Co-authored-by: drbh <drbh@users.noreply.github.com>

* remove unrelated change

* not working yet but let's see where it goes!

* update the api a bit

* update

* where I am at for now

* fix ep

* refactor the API

* yups

* fix

* fixup

* clean modeling

* just support llama4 for now!

* properly avoid

* fix

* nits

* Update src/transformers/models/llama4/modeling_llama4.py

* Update src/transformers/integrations/tensor_parallel.py

* style

* ,,,,

* update

---------

Co-authored-by: Nouamane Tazi <NouamaneTazi@users.noreply.github.com>
Co-authored-by: drbh <drbh@users.noreply.github.com>
This commit is contained in:
Arthur
2025-07-25 19:46:17 +02:00
committed by GitHub
parent abaa043d60
commit 300d42a43e
9 changed files with 436 additions and 186 deletions

View File

@@ -0,0 +1,33 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ..utils import _LazyModule
# Maps each submodule of this package to the public names it provides.
_import_structure = {"configuration_utils": ["DistributedConfig"]}

if TYPE_CHECKING:
    # Type checkers get the concrete symbol through a regular import.
    from .configuration_utils import DistributedConfig
else:
    import sys

    # At runtime, replace this package's entry in `sys.modules` with a `_LazyModule`
    # built from `_import_structure` (presumably so `configuration_utils` is only
    # imported on first attribute access -- confirm against `_LazyModule`).
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)

View File

@@ -0,0 +1,111 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import json
import os
from dataclasses import dataclass
from typing import Any, Union
@dataclass
class DistributedConfig:
    """
    Base class for distributed configs.

    The configuration is a plain dataclass: every field is stored in the instance `__dict__`, which is what
    `to_dict()`, `to_json_string()` and iteration serialize.
    """

    # When `True`, `distribute_model` uses the expert-parallel (`_ep_plan`) plans instead of the `_tp_plan` ones.
    enable_expert_parallel: bool = False
    # TODO: add tp_plan, pp_plan, device_mesh etc..

    @classmethod
    def from_dict(cls, config_dict, **kwargs):
        """
        Constructs a DistributedConfig instance from a dictionary of parameters.

        Args:
            config_dict (Dict[str, Any]): Dictionary containing configuration parameters. Every key must be a
                field of the dataclass, otherwise the constructor raises a `TypeError`.
            **kwargs: Additional keyword arguments to override dictionary values. Keys that do not match an
                existing attribute are ignored.

        Returns:
            DistributedConfig: Instance of DistributedConfig constructed from the dictionary.
        """
        config = cls(**config_dict)
        # Overrides are applied after construction; `kwargs` is a local dict, so there is nothing to pop from it.
        for key, value in kwargs.items():
            if hasattr(config, key):
                setattr(config, key, value)
        return config

    # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.to_json_file
    def to_json_file(self, json_file_path: Union[str, os.PathLike]):
        """
        Save this instance to a JSON file.

        Args:
            json_file_path (`str` or `os.PathLike`):
                Path to the JSON file in which this configuration instance's parameters will be saved.
            use_diff (`bool`, *optional*, defaults to `True`):
                If set to `True`, only the difference between the config instance and the default
                `QuantizationConfig()` is serialized to JSON file.
        """
        with open(json_file_path, "w", encoding="utf-8") as writer:
            config_dict = self.to_dict()
            json_string = json.dumps(config_dict, indent=2, sort_keys=True) + "\n"

            writer.write(json_string)

    def to_dict(self) -> dict[str, Any]:
        """
        Serializes this instance to a Python dictionary.

        Returns:
            `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance. The
            dictionary is a deep copy, so mutating it does not affect the config.
        """
        return copy.deepcopy(self.__dict__)

    # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__iter__
    def __iter__(self):
        """allows `dict(obj)` for situations where obj may be a dict or QuantizationConfigMixin"""
        for attr, value in copy.deepcopy(self.__dict__).items():
            yield attr, value

    # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__repr__
    def __repr__(self):
        return f"{self.__class__.__name__} {self.to_json_string()}"

    def to_json_string(self):
        """
        Serializes this instance to a JSON formatted string.

        Returns:
            str: JSON formatted string representing the configuration instance.
        """
        return json.dumps(self.__dict__, indent=2) + "\n"

    def update(self, **kwargs):
        """
        Updates attributes of this class instance with attributes from `kwargs` if they match existing attributes,
        returning all the unused kwargs.

        Args:
            kwargs (`Dict[str, Any]`):
                Dictionary of attributes to tentatively update this class.

        Returns:
            `Dict[str, Any]`: Dictionary containing all the key-value pairs that were not used to update the instance.
        """
        # Single pass: apply matching keys, collect the rest. The input dict itself is never modified.
        unused_kwargs = {}
        for key, value in kwargs.items():
            if hasattr(self, key):
                setattr(self, key, value)
            else:
                unused_kwargs[key] = value
        return unused_kwargs

View File

@@ -52,6 +52,12 @@ try:
layer_name="TritonLlamaMLP",
)
},
"MegaBlocksMoeMLP": {
"cuda": LayerRepository(
repo_id="kernels-community/megablocks",
layer_name="MegaBlocksMoeMLP",
)
},
}
register_kernel_mapping(_KERNEL_MAPPING)

View File

@@ -23,6 +23,7 @@ import torch
import torch.distributed as dist
from torch import nn
from ..distributed import DistributedConfig
from ..utils import is_torch_greater_or_equal, logging
from ..utils.generic import GeneralInterface
@@ -90,7 +91,7 @@ def initialize_tensor_parallelism(tp_plan, tp_size=None):
device_map = tp_device
tp_size = tp_size if tp_size is not None else torch.distributed.get_world_size()
device_mesh = torch.distributed.init_device_mesh(tp_device.type, (tp_size,))
return tp_device, device_map, device_mesh
return tp_device, device_map, device_mesh, tp_size
def _blocks_to_block_sizes(total_size: int, blocks: int | list[int]) -> list[int]:
@@ -119,20 +120,22 @@ def _blocks_to_block_sizes(total_size: int, blocks: int | list[int]) -> list[int
return [single_size] * blocks
def _get_parameter_tp_plan(parameter_name: str, tp_plan: dict[str, str]) -> str | None:
def _get_parameter_tp_plan(parameter_name: str, tp_plan: dict[str, str], is_weight=True) -> str | None:
"""
Get the TP style for a parameter from the TP plan.
The TP plan is a dictionary that maps parameter names to TP styles.
The parameter name can be a generic name with wildcards (e.g. "*.weight") or a specific name (e.g. "layer_1.weight").
The `is_weight` is important because for weights, we want to support `.weights` and `.bias` cases seamlessly! but
not parrent classes for `post_init` calls
"""
generic_param_name = re.sub(r"\d+", "*", parameter_name)
if generic_param_name in tp_plan:
return tp_plan[generic_param_name]
elif "." in generic_param_name and generic_param_name.rsplit(".", 1)[0] in tp_plan:
elif "." in generic_param_name and generic_param_name.rsplit(".", 1)[0] in tp_plan and is_weight:
return tp_plan[generic_param_name.rsplit(".", 1)[0]]
else:
return None
return None
str_to_torch_dtype = {
@@ -198,8 +201,10 @@ def get_packed_weights(param, empty_param, device_mesh, rank, dim):
slice_dtype = slice_.get_dtype()
# Handle F8_E4M3 dtype by converting to float16 before slicing
# Without upcasting, the slicing causes : RuntimeError: "index_cpu" not implemented for 'Float8_e4m3fn'
if slice_dtype == "F8_E4M3":
casted = False
if slice_dtype == "F8_E4M3" or slice_dtype == "F8_E5M2":
slice_ = slice_[...].to(torch.float16)
casted = True
if dim == 0:
tensor = slice_[tensors_slices, ...]
@@ -209,7 +214,11 @@ def get_packed_weights(param, empty_param, device_mesh, rank, dim):
tensor = slice_[..., tensors_slices]
else:
raise ValueError(f"Unsupported dim {dim}, only dim 0, 1 or 2 are supported")
return tensor.to(str_to_torch_dtype[slice_dtype])
if casted:
return tensor
else:
return tensor.to(str_to_torch_dtype[slice_dtype])
def repack_weights(
@@ -423,16 +432,27 @@ class GatherParallel(TensorParallelLayer):
@staticmethod
def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
mod.expert_parallel_group = device_mesh.get_group()
if inputs and isinstance(inputs[0], DTensor):
inputs = inputs[0].to_local()
return inputs
@staticmethod
def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
# this op cannot be async, otherwise it completely breaks the outputs of models
torch.distributed.all_reduce(outputs[0], op=torch.distributed.ReduceOp.SUM, async_op=False)
if isinstance(outputs, torch.Tensor):
dist.all_reduce(outputs, op=dist.ReduceOp.SUM, async_op=False)
else:
dist.all_reduce(outputs[0], op=dist.ReduceOp.SUM, async_op=False)
return outputs
def prepare_module_tp(self, module: nn.Module, device_mesh) -> nn.Module:
distribute_module(
module,
device_mesh,
partial(self._prepare_input_fn, None, None),
partial(self._prepare_output_fn, None, None),
)
class IsolatedParallel(TensorParallelLayer):
"""
@@ -453,6 +473,14 @@ class IsolatedParallel(TensorParallelLayer):
# TODO: figure out dynamo support for instance method and switch this to instance method
return outputs
def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
    """
    Load the whole `param`, cast it to `param_casting_dtype` and scale it down by the size of `device_mesh`.

    `empty_param`, `param_type` and `rank` are accepted for signature compatibility and are not used here.
    """
    full_param = param[...].to(param_casting_dtype)
    if to_contiguous:
        full_param = full_param.contiguous()
    # TODO should be optionable
    # TODO: assumes parent module will allreduce the output afterwards (e.g rowlinear bias is IsolatedParallel and parent module is GatherParallel)
    return full_param / device_mesh.size()
def prepare_module_tp(self, module: nn.Module, device_mesh) -> nn.Module:
distribute_module(
module,
@@ -773,6 +801,108 @@ class SequenceParallel(TensorParallelLayer):
return nn.Parameter(parameter, requires_grad=parameter.is_floating_point())
class GroupedGemmParallel(TensorParallelLayer):
    """
    Applies Expert Parallelism to MoE experts by loading the correct experts on each device.
    """

    def __init__(self):
        super().__init__()
        self.use_dtensor = False

    def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
        """
        Return the contiguous slice of experts owned by `rank`, cast to `param_casting_dtype`.

        The first dimension of `empty_param` is taken as the global number of experts and is split evenly
        across the devices of `device_mesh`; rank `r` keeps experts `[r * local, (r + 1) * local)`.

        Raises:
            ValueError: If the global number of experts is not divisible by the number of devices.
        """
        ep_rank = rank
        ep_size = device_mesh.size()
        global_num_experts = empty_param.shape[0]
        if global_num_experts % ep_size != 0:
            raise ValueError(
                f"Global number of experts must be divisible by number of devices: {global_num_experts} % {ep_size} != 0"
            )
        local_num_experts = global_num_experts // ep_size
        param = param[ep_rank * local_num_experts : (ep_rank + 1) * local_num_experts].to(param_casting_dtype)
        if to_contiguous:
            param = param.contiguous()
        # The former `"gate_up" in param_type and False` branch could never run and has been dropped.
        return param
class RouterParallel(TensorParallelLayer):
    """
    Allows to reshape the router scores to support running expert parallel.
    """

    def __init__(self, *args, **kwargs):
        # Initialize the base layer like the sibling parallel styles do, then opt out of DTensor.
        super().__init__()
        self.args = args
        self.kwargs = kwargs
        self.use_dtensor = False

    @staticmethod
    def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
        """Return the first positional input unchanged; DTensor inputs are rejected."""
        input_tensor = inputs[0]
        if isinstance(input_tensor, DTensor):
            raise NotImplementedError("RouterParallel does not support DTensor input for now")
        return input_tensor

    @staticmethod
    def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
        """
        Imagine if you had 4 tokens, top_k = 4, and 128 experts.
        With EP = 8.
        Imagine router_indices being:
        [ 52,  42, 119,  67],
        [102,  89,  61,  40],
        [ 82, 103,   4,  34],
        [ 93,  23, 109,  11],

        then you can map which rank should be getting which values

        [3, 2, 7, 4],
        [6, 5, 3, 2],
        [5, 6, 0, 2],
        [5, 1, 6, 0],

        Thus for say rank 0, you fill with 0 the index tensor

        [ 0, 0, 0,  0],
        [ 0, 0, 0,  0],
        [ 0, 0, 4,  0],
        [ 0, 0, 0, 11],

        This works well. For another rank you need to make sure you round to num_local_expert
        because the next operation will one hot encode the router index vector.

        This allows us to know directly which local expert is hit.
        Similarly the scores are indexed with something created from
        router_indices.

        The kinda naive training loop that we use for device_map "auto" uses a similar logic.
        Here we are just making each rank believe that it is alone, and it computes its part of the hidden states.
        """
        ep_rank, ep_size = device_mesh.get_local_rank(), device_mesh.size()
        num_local_experts = mod.num_experts // ep_size
        # NOTE(review): the second output is treated as integer expert indices -- confirm the hooked
        # router really returns indices (and not e.g. raw logits) in that position.
        router_scores, router_indices = outputs
        # Keep only the score columns of the experts hosted on this rank.
        router_scores = router_scores[:, ep_rank * num_local_experts : (ep_rank + 1) * num_local_experts]
        # Indices owned by other ranks are zeroed, then everything is mapped to a local expert id.
        router_indices = router_indices.masked_fill((router_indices // num_local_experts) != ep_rank, 0)
        router_indices = router_indices % num_local_experts
        return router_scores, router_indices

    def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
        """Load the full parameter (no sharding) and cast it to `param_casting_dtype`."""
        # TODO: i'd like for this to be the default
        param = param[...].to(param_casting_dtype)
        if to_contiguous:
            param = param.contiguous()
        return param

    def prepare_module_tp(self, module: nn.Module, device_mesh) -> nn.Module:
        """Register the input/output preparation functions on `module` through `distribute_module`."""
        # TODO: need an abstract Parallel class that is different from TensorParallelLayer
        distribute_module(
            module,
            device_mesh,
            partial(self._prepare_input_fn, None, None),
            partial(self._prepare_output_fn, None, None),
        )
class ParallelInterface(GeneralInterface):
# Class instance object, so that a call to `register` can be reflected into all other files correctly, even if
# a new instance is created (in order to locally override a given entry)
@@ -789,6 +919,8 @@ class ParallelInterface(GeneralInterface):
"local_packed_rowwise": PackedRowwiseParallel(use_dtensor=False),
"sequence_parallel": SequenceParallel(),
"replicate": ReplicateParallel(),
"grouped_gemm": GroupedGemmParallel(),
"ep_router": RouterParallel(),
}
if is_torch_greater_or_equal("2.5") and _torch_distributed_available
else {}
@@ -841,25 +973,17 @@ def replace_state_dict_local_with_dtensor(
return state_dict
def add_tensor_parallel_hooks_to_module(model, module, tp_plan, layer_name, current_module_plan, device_mesh):
"""
Add hooks to the module holding the layer. Meaning:
```
class MyModel(nn.Module):
def __init__(self):
self.layer = nn.Linear(10, 10)
```
has state_dict like:
```
{
"layer.weight": torch.Tensor,
"layer.bias": torch.Tensor
}
```
we add hooks to `MyModel` as well as `layer` to make sure that the tensors are correctly sharded and gathered.
"""
def add_tensor_parallel_hooks_to_module(
model, module, tp_plan, layer_name, current_module_plan, device_mesh, parameter_name=None
):
r"""
This function is called in `PretrainedModel.post_init()`. It is responsible of adding hooks
to the modules of the `model`, based on the `PretrainedModel._tp_plan`.
# 1. We add hooks to the layer being loaded:
This is the place where we add the `pre_forward` and `post_forwards` hooks. These are defined
for each `TensorParallelLayer` as `_prepare_input_fn` and `_prepare_output_fn`.
"""
if current_module_plan is not None:
tp_layer = ALL_PARALLEL_STYLES[current_module_plan]
try:
@@ -868,26 +992,19 @@ def add_tensor_parallel_hooks_to_module(model, module, tp_plan, layer_name, curr
print(
f"Trying to prepare {layer_name}, but it's not supported. Corresponding module: {module} Fix it's TP plan: {e}"
)
module._hf_tp_plan = current_module_plan
module.__repr__ = lambda: f"{module.__repr__()}\nTP Plan: {current_module_plan}"
# 2. We add hooks to the parent module if needed
if "." in layer_name:
parent_layer_name = layer_name.rsplit(".", 1)[0]
generic_name = re.sub(r"\d+", "*", parent_layer_name)
# The module itself needs hooks
if module_plan := tp_plan.get(generic_name, False):
tp_layer = ALL_PARALLEL_STYLES[module_plan]
module_to_tp_ = model.get_submodule(parent_layer_name)
tp_layer.prepare_module_tp(module_to_tp_, device_mesh)
module_to_tp_._hf_tp_plan = current_module_plan
module_to_tp_.__repr__ = lambda: f"{module_to_tp_.__repr__()}\nTP Plan: {current_module_plan}"
def shard_and_distribute_module(
model, param, empty_param, parameter_name, param_casting_dtype, is_contiguous, rank, device_mesh
):
): # TODO: rename to shard_and_distribute_param
r"""
This function is called in `from_pretrained` when loading a model's checkpoints.
It receives the pointer to the parameter (or the parameter itself) and takes care of "sharding".
All process run this function, so they just load the partition of the tensor that they require.
Main uses cases:
- column / rowise parallelism, you just shard all the weights of the layer (weight and bias)
- packed layers: you slice the weights, then shard like above
@@ -898,39 +1015,33 @@ def shard_and_distribute_module(
"""
param_name, param_type = parameter_name.rsplit(".", 1) if "." in parameter_name else parameter_name
tp_plan = model._tp_plan
module_to_tp = model.get_submodule(param_name)
module_to_tp = model.get_submodule(param_name) # TODO: can i loop over modules?
rank = int(rank)
current_shard_plan = _get_parameter_tp_plan(parameter_name, tp_plan)
current_module_plan = _get_parameter_tp_plan(parameter_name, tp_plan)
if dist.get_rank() == 0:
if current_shard_plan is None:
logger.info(f"Tensor sharding plan for {param_name} not found, using default 'replicate' plan.")
else:
logger.info(f"Tensor sharding plan for {param_name}: {current_shard_plan}")
if current_module_plan is None:
current_module_plan = "replicate"
if dist.get_rank() == 0:
logger.info(f"Tensor parallel plan for {param_name} not found, using default 'replicate' plan.")
if current_shard_plan is not None:
try:
tp_layer = ALL_PARALLEL_STYLES[current_shard_plan]
param = tp_layer.partition_tensor(
param, empty_param, param_type, param_casting_dtype, is_contiguous, rank, device_mesh
)
except NotImplementedError as e:
print(
f"Trying to prepare {parameter_name}, but it's not supported. Corresponding module: {module_to_tp} Fix it's TP plan, current layer: {tp_layer} : {e}"
)
else:
if dist.get_rank() == 0:
logger.info(f"Tensor parallel plan for {param_name}: {current_module_plan}")
# Add hooks to the module if not done yet
# add_tensor_parallel_hooks_to_module(model, module_to_tp, tp_plan, param_name, current_module_plan, device_mesh)
if not getattr(module_to_tp, "_is_hooked", False):
add_tensor_parallel_hooks_to_module(model, module_to_tp, tp_plan, param_name, current_module_plan, device_mesh)
module_to_tp._is_hooked = True
try:
tp_layer = ALL_PARALLEL_STYLES[current_module_plan]
param = tp_layer.partition_tensor(
param, empty_param, param_type, param_casting_dtype, is_contiguous, rank, device_mesh
)
except NotImplementedError as e:
print(
f"Trying to prepare {parameter_name}, but it's not supported. Corresponding module: {module_to_tp} Fix it's TP plan, current layer: {tp_layer} : {e}"
)
param = param[:].to(param_casting_dtype)
# SUPER IMPORTANT we have to use setattr
# otherwise loading is crazy slow
if not isinstance(param, torch.nn.Parameter):
param = torch.nn.Parameter(param, requires_grad=param.is_floating_point())
param = torch.nn.Parameter(param, requires_grad=empty_param.is_floating_point())
setattr(module_to_tp, param_type, param)
# module_to_tp.load_state_dict({param_type: param}, strict=False, assign=True)
return param
@@ -965,3 +1076,43 @@ def verify_tp_plan(expected_keys: list[str], tp_plan: dict[str, str] | None):
logger.warning(f"The following TP rules were not applied on any of the layers: {unused_rules}")
if len(unsharded_layers) > 0:
logger.warning(f"The following layers were not sharded: {', '.join(unsharded_layers)}")
def distribute_model(model, distributed_config, device_mesh, tp_size):
    """
    Build `model._tp_plan` and attach the parallel hooks to every module of `model`.

    Args:
        model: The freshly instantiated model; its `config` provides `base_model_tp_plan` and, when expert
            parallelism is enabled, `base_model_ep_plan`.
        distributed_config (`DistributedConfig`, `dict` or `None`): Distributed options. A `dict` is converted
            with `DistributedConfig.from_dict`.
        device_mesh: Device mesh forwarded to `add_tensor_parallel_hooks_to_module`.
        tp_size: Parallel degree, recorded on the model as `_tp_size`.

    Returns:
        The same `model`, with `_tp_plan`, `_tp_size` and `_device_mesh` set and its modules hooked.

    Raises:
        ValueError: If the resulting plan references a style that is not in `ALL_PARALLEL_STYLES`.
    """
    _plan = "_tp_plan"
    # The config attribute may be `None` (no plan): fall back to an empty plan instead of failing on `.copy()`.
    model._tp_plan = (getattr(model.config, "base_model_tp_plan", None) or {}).copy()
    # Record the degree and mesh the model is distributed with; `tp_size` was otherwise unused here.
    model._tp_size = tp_size
    model._device_mesh = device_mesh

    if distributed_config is not None:
        # `DistributedConfig` only defines `from_dict`; an instance is used as-is.
        if isinstance(distributed_config, dict):
            distributed_config = DistributedConfig.from_dict(distributed_config)
        if distributed_config.enable_expert_parallel:
            _plan = "_ep_plan"
            ep_plan = getattr(model.config, "base_model_ep_plan", None)
            if ep_plan is not None:
                model._tp_plan = ep_plan.copy()

    # now fetch my childrens
    for name, module in model.named_children():
        if plan := getattr(module, _plan, getattr(module, "tp_plan", None)):
            model._tp_plan.update({f"{name}.{k}": v for k, v in plan.copy().items()})
        if hasattr(module, "config"):
            # Prefer the plan matching the active mode, then the TP plan; a missing/`None`/empty plan adds nothing.
            plan = (
                getattr(module.config, f"base_model{_plan}", None)
                or getattr(module.config, "base_model_tp_plan", None)
                or {}
            )
            model._tp_plan.update({f"{name}.{k}": v for k, v in plan.copy().items()})

    if model._tp_plan is not None and is_torch_greater_or_equal("2.5") and _torch_distributed_available:
        for v in model._tp_plan.values():
            if v not in ALL_PARALLEL_STYLES:
                raise ValueError(f"Unsupported tensor parallel style {v}. Supported styles are {ALL_PARALLEL_STYLES}")

    for name, module in model.named_modules():
        if not getattr(module, "_is_hooked", False):
            # `is_weight=False`: match the module name itself, not a parent entry of the plan.
            plan = _get_parameter_tp_plan(parameter_name=name, tp_plan=model._tp_plan, is_weight=False)
            add_tensor_parallel_hooks_to_module(
                model=model,
                module=module,
                tp_plan=model._tp_plan,
                layer_name="",
                current_module_plan=plan,
                device_mesh=device_mesh,
            )
            module._is_hooked = True
    return model

View File

@@ -63,8 +63,8 @@ from .integrations.flex_attention import flex_attention_forward
from .integrations.sdpa_attention import sdpa_attention_forward
from .integrations.sdpa_paged import sdpa_attention_paged_forward
from .integrations.tensor_parallel import (
ALL_PARALLEL_STYLES,
_get_parameter_tp_plan,
distribute_model,
initialize_tensor_parallelism,
repack_weights,
replace_state_dict_local_with_dtensor,
@@ -2218,6 +2218,9 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
"""
A method executed at the end of each Transformer model initialization, to execute code that needs the model's
modules properly initialized (such as weight initialization).
This is also used when the user is running distributed code. We add hooks to the modules here, according to
the model's tp_plan!
"""
self.init_weights()
self._backward_compatibility_gradient_checkpointing()
@@ -2250,17 +2253,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
# If current model is a base model, attach `base_model_tp_plan` and `base_model_pp_plan` from config
self._pp_plan = self.config.base_model_pp_plan.copy() if self.config.base_model_pp_plan is not None else None
self._tp_plan = self.config.base_model_tp_plan.copy() if self.config.base_model_tp_plan is not None else {}
for name, module in self.named_children():
if plan := getattr(module, "_tp_plan", None):
self._tp_plan.update({f"{name}.{k}": v for k, v in plan.copy().items()})
if self._tp_plan is not None and is_torch_greater_or_equal("2.5") and _torch_distributed_available:
for v in self._tp_plan.values():
if v not in ALL_PARALLEL_STYLES:
raise ValueError(
f"Unsupported tensor parallel style {v}. Supported styles are {ALL_PARALLEL_STYLES}"
)
def dequantize(self):
"""
@@ -4568,6 +4560,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
load_in_8bit = kwargs.pop("load_in_8bit", False)
load_in_4bit = kwargs.pop("load_in_4bit", False)
quantization_config = kwargs.pop("quantization_config", None)
distributed_config = kwargs.pop("distributed_config", None)
subfolder = kwargs.pop("subfolder", "")
commit_hash = kwargs.pop("_commit_hash", None)
variant = kwargs.pop("variant", None)
@@ -4588,6 +4581,9 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
):
key_mapping = cls._checkpoint_conversion_mapping
if distributed_config is not None:
tp_plan = "auto"
# Not used anymore -- remove them from the kwargs
_ = kwargs.pop("resume_download", None)
_ = kwargs.pop("mirror", None)
@@ -4619,16 +4615,12 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
# `device_map` pointing to the correct device
if tp_plan is not None:
if device_mesh is None:
tp_plan, device_map, device_mesh = initialize_tensor_parallelism(tp_plan, tp_size=None)
tp_plan, device_map, device_mesh, tp_size = initialize_tensor_parallelism(tp_plan, tp_size=tp_size)
else:
if "tp" not in device_mesh.mesh_dim_names:
raise ValueError(
"When using `tp_plan`, the `device_mesh` must contain a 'tp' dimension. "
"Please provide a valid `device_mesh`."
)
device_mesh = device_mesh["tp"]
tp_size = device_mesh["tp"].size()
device_map = torch.device(f"{device_mesh.device_type}:{int(os.environ['LOCAL_RANK'])}")
# TODO: make device_mesh support multiple dimensions
if device_mesh.ndim > 1:
raise ValueError("device_mesh must be 1 dimensional and will be used for TP")
device_map = torch.device(device_mesh.device_type, int(os.environ["LOCAL_RANK"]))
if tp_size is None:
tp_size = torch.distributed.get_world_size()
@@ -4928,23 +4920,18 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
)
config.name_or_path = pretrained_model_name_or_path
# Instantiate model.
model_init_context = cls.get_init_context(is_quantized, _is_ds_init_called)
config = copy.deepcopy(config) # We do not want to modify the config inplace in from_pretrained.
with ContextManagers(model_init_context):
# Let's make sure we don't run the init function of buffer modules
model = cls(config, *model_args, **model_kwargs)
if _torch_distributed_available and device_mesh is not None:
model = distribute_model(model, distributed_config, device_mesh, tp_size)
# Make sure to tie the weights correctly
model.tie_weights()
# Last check for tp
if device_mesh is not None and not model.supports_tp_plan:
if config.base_model_tp_plan is None and config.get_text_config().base_model_tp_plan is None:
raise NotImplementedError("This model does not have a tensor parallel plan.")
# make sure we use the model's config since the __init__ call might have copied it
config = model.config
@@ -5025,11 +5012,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
key_mapping=key_mapping,
weights_only=weights_only,
)
# record tp degree the model sharded to
model._tp_size = tp_size
model._device_mesh = device_mesh
# make sure token embedding weights are still tied if needed
model.tie_weights()

View File

@@ -265,6 +265,19 @@ class Llama4TextConfig(PretrainedConfig):
"layers.*.feed_forward.down_proj": "local_rowwise",
"layers.*.feed_forward": "gather",
}
base_model_ep_plan = {
"layers.*.self_attn.q_proj": "colwise",
"layers.*.self_attn.k_proj": "colwise",
"layers.*.self_attn.v_proj": "colwise",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.feed_forward.experts.gate_up_proj": "grouped_gemm", # row because not linear
"layers.*.feed_forward.experts.down_proj": "grouped_gemm", # col because not linear
"layers.*.feed_forward.experts": "gather", # all reduce
"layers.*.feed_forward.gate_proj": "local_colwise",
"layers.*.feed_forward.up_proj": "local_colwise",
"layers.*.feed_forward.down_proj": "local_rowwise",
"layers.*.feed_forward.router": "ep_router",
}
def __init__(
self,

View File

@@ -26,7 +26,7 @@ from transformers.models.llama4.configuration_llama4 import Llama4VisionConfig
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...integrations.hub_kernels import use_kernel_forward_from_hub
from ...integrations import use_kernel_forward_from_hub
from ...masking_utils import create_causal_mask, create_chunked_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
@@ -35,6 +35,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
from ...utils.generic import check_model_inputs
from .configuration_llama4 import Llama4Config, Llama4TextConfig
@@ -65,7 +66,7 @@ class Llama4TextExperts(nn.Module):
Returns:
torch.Tensor
"""
hidden_states = hidden_states.view(self.num_experts, -1, self.hidden_size)
hidden_states = hidden_states.view(self.gate_up_proj.shape[0], -1, self.hidden_size)
gate_up = torch.bmm(hidden_states, self.gate_up_proj)
gate, up = gate_up.chunk(2, dim=-1) # not supported for DTensors
next_states = torch.bmm((up * self.act_fn(gate)), self.down_proj)
@@ -127,6 +128,20 @@ class Llama4TextRMSNorm(nn.Module):
return f"{tuple(self.weight.shape)}, eps={self.eps}"
class Llama4Router(nn.Linear):
    """
    Bias-free linear router mapping `config.hidden_size` features to `config.num_local_experts` logits.

    `forward` returns `(router_scores, router_logits)`: the scores are the sigmoid of the logits with every
    entry outside the per-row top-k replaced by `-inf` beforehand (so those scores are 0), and the logits
    are the raw linear output.
    """

    def __init__(self, config):
        super().__init__(config.hidden_size, config.num_local_experts, bias=False)
        self.num_experts = config.num_local_experts
        self.top_k = config.num_experts_per_tok

    def forward(self, hidden_states):
        router_logits = super().forward(hidden_states)
        top_values, top_indices = torch.topk(router_logits, self.top_k, dim=1)
        # Start from -inf everywhere and write back only the top-k logits of each row.
        masked_logits = torch.full_like(router_logits, float("-inf"))
        masked_logits.scatter_(1, top_indices, top_values)
        # Sigmoid is evaluated in float32, then cast back to the logits dtype.
        router_scores = torch.nn.functional.sigmoid(masked_logits.float()).to(masked_logits.dtype)
        return router_scores, router_logits
@use_kernel_forward_from_hub("Llama4TextMoe")
class Llama4TextMoe(nn.Module):
def __init__(self, config):
@@ -135,28 +150,18 @@ class Llama4TextMoe(nn.Module):
self.hidden_dim = config.hidden_size
self.num_experts = config.num_local_experts
self.experts = Llama4TextExperts(config)
self.router = nn.Linear(config.hidden_size, config.num_local_experts, bias=False)
self.router = Llama4Router(config)
self.shared_expert = Llama4TextMLP(config)
def forward(self, hidden_states):
hidden_states = hidden_states.reshape(-1, self.hidden_dim)
router_logits = self.router(hidden_states)
router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=1)
router_scores = (
torch.full_like(router_logits, float("-inf")).scatter_(1, router_indices, router_top_value).transpose(0, 1)
)
router_scores = torch.sigmoid(router_scores.float()).to(hidden_states.dtype)
routed_in = hidden_states.repeat(self.num_experts, 1)
router_scores, router_logits = self.router(hidden_states)
routed_in = hidden_states.repeat(router_scores.shape[1], 1)
routed_in = routed_in * router_scores.reshape(-1, 1)
routed_out = self.experts(routed_in)
out = self.shared_expert(hidden_states)
out.add_(routed_out.reshape(self.num_experts, -1, self.hidden_dim).sum(dim=0))
return out, router_scores
out.add_(routed_out.reshape(router_scores.shape[1], -1, routed_out.shape[-1]).sum(dim=0))
return out, router_logits
class Llama4TextRotaryEmbedding(nn.Module):
@@ -383,8 +388,6 @@ class Llama4TextDecoderLayer(GradientCheckpointingLayer):
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
output_router_logits: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
@@ -395,12 +398,11 @@ class Llama4TextDecoderLayer(GradientCheckpointingLayer):
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
attention_states, self_attn_weights = self.self_attn(
attention_states, _ = self.self_attn(
hidden_states=hidden_states,
position_embeddings=position_embeddings,
attention_mask=attention_mask,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
**kwargs,
@@ -409,23 +411,12 @@ class Llama4TextDecoderLayer(GradientCheckpointingLayer):
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.feed_forward(hidden_states)
if self.is_moe_layer:
hidden_states, router_logits = hidden_states
else:
router_logits = None
hidden_states, _ = hidden_states
hidden_states = residual + hidden_states.view(residual.shape)
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
if output_router_logits:
outputs += (router_logits,)
return outputs
return hidden_states
@auto_docstring
@@ -472,6 +463,11 @@ class Llama4TextModel(Llama4PreTrainedModel):
_no_split_modules = ["Llama4TextDecoderLayer"]
base_model_prefix = "model"
config: Llama4TextConfig
_can_record_outputs = {
"attentions": Llama4TextAttention,
"hidden_states": Llama4TextDecoderLayer,
"router_logits": Llama4TextMoe,
}
def __init__(self, config: Llama4TextConfig):
super().__init__(config)
@@ -489,7 +485,7 @@ class Llama4TextModel(Llama4PreTrainedModel):
# Initialize weights and apply final processing
self.post_init()
@can_return_tuple
@check_model_inputs
@auto_docstring
def forward(
self,
@@ -499,28 +495,12 @@ class Llama4TextModel(Llama4PreTrainedModel):
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[TransformersKwargs],
) -> Union[tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if self.gradient_checkpointing and self.training and use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
)
use_cache = False
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids.to(self.embed_tokens.weight.device))
@@ -558,42 +538,22 @@ class Llama4TextModel(Llama4PreTrainedModel):
# create position embeddings to be shared across the decoder layers
freq_cis = self.rotary_emb(hidden_states, position_ids)
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
if output_hidden_states:
all_hidden_states += (hidden_states,)
layer_outputs = decoder_layer(
hidden_states = decoder_layer(
hidden_states,
attention_mask=causal_mask_mapping[decoder_layer.attention_type],
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=freq_cis,
**kwargs,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=past_key_values if use_cache else None,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
@@ -630,9 +590,6 @@ class Llama4ForCausalLM(Llama4PreTrainedModel, GenerationMixin):
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs: Unpack[TransformersKwargs],
@@ -659,13 +616,6 @@ class Llama4ForCausalLM(Llama4PreTrainedModel, GenerationMixin):
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
@@ -673,9 +623,6 @@ class Llama4ForCausalLM(Llama4PreTrainedModel, GenerationMixin):
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
cache_position=cache_position,
**kwargs,
)

View File

@@ -2634,7 +2634,14 @@ class Trainer:
self.control = self.callback_handler.on_pre_optimizer_step(args, self.state, self.control)
self.optimizer.step()
context = contextlib.nullcontext
if self.is_tp_enabled:
from torch.distributed._tensor.experimental import implicit_replication
context = implicit_replication
with context():
self.optimizer.step()
self.control = self.callback_handler.on_optimizer_step(args, self.state, self.control)

View File

@@ -109,7 +109,7 @@ class TestTensorParallel(TestCasePlus):
assert has_dtensor == 1, "TP model must has DTensor"
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id, legacy=False)
prompt = "Can I help"
inputs = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)