Add back _tp_plan attribute (#39944)

* Update modeling_utils.py

* make sure we update with the module's plan

* use public api

* oups

* update

* fix failing test

* Update src/transformers/integrations/tensor_parallel.py

* Update src/transformers/integrations/tensor_parallel.py

* fix

* make the API more friendly!

* fix tests

* fix styling

---------

Co-authored-by: Arthur Zucker <arthur.zucker@gmail.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
Rishub Tamirisa
2025-08-20 06:29:55 -07:00
committed by GitHub
parent a97213d131
commit c50f140be2
5 changed files with 213 additions and 26 deletions

View File

@@ -23,7 +23,12 @@ from os.path import abspath, dirname, join
import _pytest
import pytest
from transformers.testing_utils import HfDoctestModule, HfDocTestParser, is_torch_available, patch_torch_compile_force_graph
from transformers.testing_utils import (
HfDoctestModule,
HfDocTestParser,
is_torch_available,
patch_torch_compile_force_graph,
)
NOT_DEVICE_TESTS = {

View File

@@ -198,6 +198,7 @@ class PretrainedConfig(PushToHubMixin):
attribute_map: dict[str, str] = {}
base_model_tp_plan: Optional[dict[str, Any]] = None
base_model_pp_plan: Optional[dict[str, tuple[list[str]]]] = None
base_model_ep_plan: Optional[dict[str, tuple[list[str]]]] = None
_auto_class: Optional[str] = None
def __setattr__(self, key, value):

View File

@@ -1013,8 +1013,7 @@ def shard_and_distribute_module(
"""
param_name, param_type = parameter_name.rsplit(".", 1) if "." in parameter_name else parameter_name
tp_plan = model._tp_plan or {}
tp_plan.update(getattr(type(model), "_tp_plan", None) or {})
tp_plan = model.tp_plan or {}
module_to_tp = model.get_submodule(param_name) # TODO: can i loop over modules?
rank = int(rank)
current_shard_plan = _get_parameter_tp_plan(parameter_name, tp_plan)
@@ -1079,42 +1078,26 @@ def verify_tp_plan(expected_keys: list[str], tp_plan: dict[str, str] | None):
def distribute_model(model, distributed_config, device_mesh, tp_size):
_plan = "_tp_plan"
tp_plan = (getattr(model, "_tp_plan", None) or {}).copy()
model._tp_plan = getattr(model.config, "base_model_tp_plan").copy()
model._tp_plan.update(tp_plan)
model._tp_size = tp_size
model._device_mesh = device_mesh
if distributed_config is not None:
if isinstance(distributed_config, dict):
distributed_config = DistributedConfig.from_dict(distributed_config)
if distributed_config.enable_expert_parallel:
_plan = "_ep_plan"
model._tp_plan = getattr(model.config, "base_model_ep_plan", model._tp_plan).copy()
# now fetch my childrens
for name, module in model.named_children():
if plan := getattr(module, _plan, getattr(module, "tp_plan", None)):
model._tp_plan.update({f"{name}.{k}": v for k, v in plan.copy().items()})
if hasattr(module, "config"):
plan = getattr(module.config, f"base_model{_plan}", {})
if plan == {}:
plan = getattr(module.config, "base_model_tp_plan", {})
model._tp_plan.update({f"{name}.{k}": v for k, v in plan.copy().items()})
if model._tp_plan is not None and is_torch_greater_or_equal("2.5") and _torch_distributed_available:
for v in model._tp_plan.values():
model.config.distributed_config = distributed_config
model_plan = model.tp_plan
if model_plan is not None and is_torch_greater_or_equal("2.5") and _torch_distributed_available:
for v in model_plan.values():
if v not in ALL_PARALLEL_STYLES:
raise ValueError(f"Unsupported tensor parallel style {v}. Supported styles are {ALL_PARALLEL_STYLES}")
for name, module in model.named_modules():
if not getattr(module, "_is_hooked", False):
from transformers.integrations.tensor_parallel import add_tensor_parallel_hooks_to_module
plan = _get_parameter_tp_plan(parameter_name=name, tp_plan=model._tp_plan, is_weight=False)
plan = _get_parameter_tp_plan(parameter_name=name, tp_plan=model_plan, is_weight=False)
add_tensor_parallel_hooks_to_module(
model=model,
module=module,
tp_plan=model._tp_plan,
tp_plan=model_plan,
layer_name="",
current_module_plan=plan,
device_mesh=device_mesh,

View File

@@ -2276,7 +2276,87 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
)
# If current model is a base model, attach `base_model_tp_plan` and `base_model_pp_plan` from config
self._pp_plan = self.config.base_model_pp_plan.copy() if self.config.base_model_pp_plan is not None else None
self._pp_plan = self.config.base_model_pp_plan.copy() if self.config.base_model_pp_plan is not None else {}
self._tp_plan = self.config.base_model_tp_plan.copy() if self.config.base_model_tp_plan is not None else {}
self._ep_plan = self.config.base_model_ep_plan.copy() if self.config.base_model_ep_plan is not None else {}
for name, module in self.named_children():
if plan := getattr(module, "_ep_plan", None):
self._ep_plan.update({f"{name}.{k}": v for k, v in plan.copy().items()})
if plan := getattr(module, "_tp_plan", None):
self._tp_plan.update({f"{name}.{k}": v for k, v in plan.copy().items()})
if plan := getattr(module, "_pp_plan", None):
self._pp_plan.update({f"{name}.{k}": v for k, v in plan.copy().items()})
@property
def tp_plan(self) -> dict[str, str]:
"""
The full tp plan for the model's modules
"""
if hasattr(self.config, "distributed_config") and self.config.distributed_config.enable_expert_parallel:
return self._ep_plan
return self._tp_plan
@property
def pp_plan(self) -> dict[str, tuple[str, str]]:
return self._pp_plan
@tp_plan.setter
def tp_plan(self, plan: dict[str, str]):
if plan is not None:
# Validate that all parallel styles in the plan are supported
from .integrations.tensor_parallel import ALL_PARALLEL_STYLES
for layer_pattern, parallel_style in plan.items():
if parallel_style not in ALL_PARALLEL_STYLES:
raise ValueError(
f"Unsupported tensor parallel style '{parallel_style}' for layer '{layer_pattern}'. "
f"Supported styles are {list(ALL_PARALLEL_STYLES.keys())}"
)
# Validate that the layer patterns match existing model structure
# We check this by getting all parameter names and seeing if any match the patterns
if hasattr(self, "named_parameters"):
model_param_names = [name for name, _ in self.named_parameters()]
if model_param_names: # Only validate if model has parameters
import re
for layer_pattern in plan.keys():
# Convert pattern to regex (replace * with .*)
regex_pattern = layer_pattern.replace("*", r"\d+")
pattern_matched = False
for param_name in model_param_names:
if re.match(regex_pattern, param_name):
pattern_matched = True
break
if not pattern_matched:
# Try more flexible matching - check if pattern components exist
pattern_parts = layer_pattern.split(".")
flexible_matched = False
for param_name in model_param_names:
param_parts = param_name.split(".")
if len(pattern_parts) <= len(param_parts):
match_count = 0
for i, pattern_part in enumerate(pattern_parts):
if pattern_part == "*":
match_count += 1
elif i < len(param_parts) and pattern_part == param_parts[i]:
match_count += 1
if match_count == len(pattern_parts):
flexible_matched = True
break
if not flexible_matched:
import warnings
warnings.warn(
f"Layer pattern '{layer_pattern}' does not match any parameters in the model. "
f"This rule may not be applied during tensor parallelization."
)
self._tp_plan = plan if plan is not None else {}
@pp_plan.setter
def pp_plan(self, plan: dict[str, tuple[str, str]]):
self._pp_plan = plan
def dequantize(self):
"""

View File

@@ -215,6 +215,124 @@ class TestTensorParallel(TestCasePlus):
del non_tp_tensor, tp_tensor
class TestTensorParallelProperties(TestCasePlus):
def test_tp_plan_property_setter_getter(self):
"""Test that tp_plan property can be set and retrieved correctly."""
from transformers import AutoModelForCausalLM
model_id = "JackFram/llama-68m"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto")
# Test setting empty plan
model.tp_plan = {}
self.assertEqual(model.tp_plan, {})
# Test setting a valid plan
valid_plan = {"model.layers.*.self_attn.q_proj": "colwise"}
model.tp_plan = valid_plan
self.assertEqual(model.tp_plan, valid_plan)
# Test updating the plan
model.tp_plan.update({"model.layers.*.self_attn.k_proj": "colwise"})
expected_plan = {"model.layers.*.self_attn.q_proj": "colwise", "model.layers.*.self_attn.k_proj": "colwise"}
self.assertEqual(model.tp_plan, expected_plan)
# Test overriding existing entry
model.tp_plan.update({"model.layers.*.self_attn.q_proj": "colwise_rep"})
expected_plan = {
"model.layers.*.self_attn.q_proj": "colwise_rep",
"model.layers.*.self_attn.k_proj": "colwise",
}
self.assertEqual(model.tp_plan, expected_plan)
def test_tp_plan_validation_invalid_style(self):
"""Test that invalid parallel styles are rejected."""
from transformers import AutoModelForCausalLM
model_id = "JackFram/llama-68m"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto")
# Test invalid parallel style
with self.assertRaises(ValueError) as context:
model.tp_plan = {"layers.*.self_attn.q_proj": "invalid_style"}
self.assertIn("Unsupported tensor parallel style 'invalid_style'", str(context.exception))
self.assertIn("Supported styles are", str(context.exception))
def test_tp_plan_validation_nonexistent_layer_warning(self):
"""Test that warnings are issued for non-existent layer patterns."""
import warnings
from transformers import AutoModelForCausalLM
model_id = "JackFram/llama-68m"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto")
# Test warning for non-existent layer pattern
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
model.tp_plan = {"nonexistent.*.layer": "colwise"}
# Check that a warning was issued
self.assertTrue(len(w) > 0)
warning_message = str(w[0].message)
self.assertIn("Layer pattern 'nonexistent.*.layer' does not match any parameters", warning_message)
def test_tp_plan_valid_layer_patterns(self):
"""Test that valid layer patterns are accepted without warnings."""
import warnings
from transformers import AutoModelForCausalLM
model_id = "JackFram/llama-68m"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto")
# Test valid layer patterns that should match the model structure
valid_plans = [
{"model.layers.*.self_attn.q_proj": "colwise"},
{"model.layers.*.self_attn.k_proj": "rowwise"},
{"model.layers.*.mlp.gate_proj": "colwise_rep"},
]
for plan in valid_plans:
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
model.tp_plan = plan
# Filter out any warnings that are not about layer patterns
layer_warnings = [
warning
for warning in w
if "Layer pattern" in str(warning.message)
and "does not match any parameters" in str(warning.message)
]
# Should not have layer pattern warnings for valid patterns
self.assertEqual(
len(layer_warnings),
0,
f"Unexpected warning for valid pattern {plan}: {[str(w.message) for w in layer_warnings]}",
)
# Verify the final plan was set correctly
self.assertEqual(model.tp_plan, valid_plans[-1])
def test_tp_plan_none_handling(self):
"""Test that None values are handled correctly."""
from transformers import AutoModelForCausalLM
model_id = "JackFram/llama-68m"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto")
# Test setting None
model.tp_plan = None
self.assertEqual(model.tp_plan, {})
# Test setting a plan after None
model.tp_plan = {"model.layers.*.self_attn.q_proj": "colwise"}
self.assertEqual(model.tp_plan, {"model.layers.*.self_attn.q_proj": "colwise"})
@require_torch_multi_accelerator
class TestTensorParallelAccelerator(TestTensorParallel):
nproc_per_node = backend_device_count(torch_device)