* EP + updates

Co-authored-by: Nouamane Tazi <NouamaneTazi@users.noreply.github.com>
Co-authored-by: drbh <drbh@users.noreply.github.com>

* remove unrelated change

* not working yet but let's see where it goes!

* update the api a bit

* udpate

* where I am at for now

* fix ep

* refactor the API

* yups

* fix

* fixup

* clean modeling

* just support llama4 for now!

* properly avoid

* fix

* nits

* Update src/transformers/models/llama4/modeling_llama4.py

* Update src/transformers/integrations/tensor_parallel.py

* style

* ,,,,

* update

---------

Co-authored-by: Nouamane Tazi <NouamaneTazi@users.noreply.github.com>
Co-authored-by: drbh <drbh@users.noreply.github.com>
This commit is contained in:
Arthur
2025-07-25 19:46:17 +02:00
committed by GitHub
parent abaa043d60
commit 300d42a43e
9 changed files with 436 additions and 186 deletions

View File

@@ -63,8 +63,8 @@ from .integrations.flex_attention import flex_attention_forward
from .integrations.sdpa_attention import sdpa_attention_forward
from .integrations.sdpa_paged import sdpa_attention_paged_forward
from .integrations.tensor_parallel import (
ALL_PARALLEL_STYLES,
_get_parameter_tp_plan,
distribute_model,
initialize_tensor_parallelism,
repack_weights,
replace_state_dict_local_with_dtensor,
@@ -2218,6 +2218,9 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
"""
A method executed at the end of each Transformer model initialization, to execute code that needs the model's
modules properly initialized (such as weight initialization).
This is also used when the user is running distributed code. We add hooks to the modules here, according to
the model's tp_plan!
"""
self.init_weights()
self._backward_compatibility_gradient_checkpointing()
@@ -2250,17 +2253,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
# If current model is a base model, attach `base_model_tp_plan` and `base_model_pp_plan` from config
self._pp_plan = self.config.base_model_pp_plan.copy() if self.config.base_model_pp_plan is not None else None
self._tp_plan = self.config.base_model_tp_plan.copy() if self.config.base_model_tp_plan is not None else {}
for name, module in self.named_children():
if plan := getattr(module, "_tp_plan", None):
self._tp_plan.update({f"{name}.{k}": v for k, v in plan.copy().items()})
if self._tp_plan is not None and is_torch_greater_or_equal("2.5") and _torch_distributed_available:
for v in self._tp_plan.values():
if v not in ALL_PARALLEL_STYLES:
raise ValueError(
f"Unsupported tensor parallel style {v}. Supported styles are {ALL_PARALLEL_STYLES}"
)
def dequantize(self):
"""
@@ -4568,6 +4560,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
load_in_8bit = kwargs.pop("load_in_8bit", False)
load_in_4bit = kwargs.pop("load_in_4bit", False)
quantization_config = kwargs.pop("quantization_config", None)
distributed_config = kwargs.pop("distributed_config", None)
subfolder = kwargs.pop("subfolder", "")
commit_hash = kwargs.pop("_commit_hash", None)
variant = kwargs.pop("variant", None)
@@ -4588,6 +4581,9 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
):
key_mapping = cls._checkpoint_conversion_mapping
if distributed_config is not None:
tp_plan = "auto"
# Not used anymore -- remove them from the kwargs
_ = kwargs.pop("resume_download", None)
_ = kwargs.pop("mirror", None)
@@ -4619,16 +4615,12 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
# `device_map` pointing to the correct device
if tp_plan is not None:
if device_mesh is None:
tp_plan, device_map, device_mesh = initialize_tensor_parallelism(tp_plan, tp_size=None)
tp_plan, device_map, device_mesh, tp_size = initialize_tensor_parallelism(tp_plan, tp_size=tp_size)
else:
if "tp" not in device_mesh.mesh_dim_names:
raise ValueError(
"When using `tp_plan`, the `device_mesh` must contain a 'tp' dimension. "
"Please provide a valid `device_mesh`."
)
device_mesh = device_mesh["tp"]
tp_size = device_mesh["tp"].size()
device_map = torch.device(f"{device_mesh.device_type}:{int(os.environ['LOCAL_RANK'])}")
# TODO: make device_mesh support multiple dimensions
if device_mesh.ndim > 1:
raise ValueError("device_mesh must be 1 dimensional and will be used for TP")
device_map = torch.device(device_mesh.device_type, int(os.environ["LOCAL_RANK"]))
if tp_size is None:
tp_size = torch.distributed.get_world_size()
@@ -4928,23 +4920,18 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
)
config.name_or_path = pretrained_model_name_or_path
# Instantiate model.
model_init_context = cls.get_init_context(is_quantized, _is_ds_init_called)
config = copy.deepcopy(config) # We do not want to modify the config inplace in from_pretrained.
with ContextManagers(model_init_context):
# Let's make sure we don't run the init function of buffer modules
model = cls(config, *model_args, **model_kwargs)
if _torch_distributed_available and device_mesh is not None:
model = distribute_model(model, distributed_config, device_mesh, tp_size)
# Make sure to tie the weights correctly
model.tie_weights()
# Last check for tp
if device_mesh is not None and not model.supports_tp_plan:
if config.base_model_tp_plan is None and config.get_text_config().base_model_tp_plan is None:
raise NotImplementedError("This model does not have a tensor parallel plan.")
# make sure we use the model's config since the __init__ call might have copied it
config = model.config
@@ -5025,11 +5012,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
key_mapping=key_mapping,
weights_only=weights_only,
)
# record tp degree the model sharded to
model._tp_size = tp_size
model._device_mesh = device_mesh
# make sure token embedding weights are still tied if needed
model.tie_weights()