Add back _tp_plan attribute (#39944)
* Update modeling_utils.py * make sure we update with the module's plan * use public api * oups * update * fix failing test * Update src/transformers/integrations/tensor_parallel.py * Update src/transformers/integrations/tensor_parallel.py * fix * make the API more friendly! * fix tests * fix styling --------- Co-authored-by: Arthur Zucker <arthur.zucker@gmail.com> Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
@@ -23,7 +23,12 @@ from os.path import abspath, dirname, join
|
||||
import _pytest
|
||||
import pytest
|
||||
|
||||
from transformers.testing_utils import HfDoctestModule, HfDocTestParser, is_torch_available, patch_torch_compile_force_graph
|
||||
from transformers.testing_utils import (
|
||||
HfDoctestModule,
|
||||
HfDocTestParser,
|
||||
is_torch_available,
|
||||
patch_torch_compile_force_graph,
|
||||
)
|
||||
|
||||
|
||||
NOT_DEVICE_TESTS = {
|
||||
|
||||
@@ -198,6 +198,7 @@ class PretrainedConfig(PushToHubMixin):
|
||||
attribute_map: dict[str, str] = {}
|
||||
base_model_tp_plan: Optional[dict[str, Any]] = None
|
||||
base_model_pp_plan: Optional[dict[str, tuple[list[str]]]] = None
|
||||
base_model_ep_plan: Optional[dict[str, tuple[list[str]]]] = None
|
||||
_auto_class: Optional[str] = None
|
||||
|
||||
def __setattr__(self, key, value):
|
||||
|
||||
@@ -1013,8 +1013,7 @@ def shard_and_distribute_module(
|
||||
|
||||
"""
|
||||
param_name, param_type = parameter_name.rsplit(".", 1) if "." in parameter_name else parameter_name
|
||||
tp_plan = model._tp_plan or {}
|
||||
tp_plan.update(getattr(type(model), "_tp_plan", None) or {})
|
||||
tp_plan = model.tp_plan or {}
|
||||
module_to_tp = model.get_submodule(param_name) # TODO: can i loop over modules?
|
||||
rank = int(rank)
|
||||
current_shard_plan = _get_parameter_tp_plan(parameter_name, tp_plan)
|
||||
@@ -1079,42 +1078,26 @@ def verify_tp_plan(expected_keys: list[str], tp_plan: dict[str, str] | None):
|
||||
|
||||
|
||||
def distribute_model(model, distributed_config, device_mesh, tp_size):
|
||||
_plan = "_tp_plan"
|
||||
tp_plan = (getattr(model, "_tp_plan", None) or {}).copy()
|
||||
model._tp_plan = getattr(model.config, "base_model_tp_plan").copy()
|
||||
model._tp_plan.update(tp_plan)
|
||||
model._tp_size = tp_size
|
||||
model._device_mesh = device_mesh
|
||||
if distributed_config is not None:
|
||||
if isinstance(distributed_config, dict):
|
||||
distributed_config = DistributedConfig.from_dict(distributed_config)
|
||||
if distributed_config.enable_expert_parallel:
|
||||
_plan = "_ep_plan"
|
||||
model._tp_plan = getattr(model.config, "base_model_ep_plan", model._tp_plan).copy()
|
||||
|
||||
# now fetch my childrens
|
||||
for name, module in model.named_children():
|
||||
if plan := getattr(module, _plan, getattr(module, "tp_plan", None)):
|
||||
model._tp_plan.update({f"{name}.{k}": v for k, v in plan.copy().items()})
|
||||
if hasattr(module, "config"):
|
||||
plan = getattr(module.config, f"base_model{_plan}", {})
|
||||
if plan == {}:
|
||||
plan = getattr(module.config, "base_model_tp_plan", {})
|
||||
model._tp_plan.update({f"{name}.{k}": v for k, v in plan.copy().items()})
|
||||
|
||||
if model._tp_plan is not None and is_torch_greater_or_equal("2.5") and _torch_distributed_available:
|
||||
for v in model._tp_plan.values():
|
||||
model.config.distributed_config = distributed_config
|
||||
model_plan = model.tp_plan
|
||||
if model_plan is not None and is_torch_greater_or_equal("2.5") and _torch_distributed_available:
|
||||
for v in model_plan.values():
|
||||
if v not in ALL_PARALLEL_STYLES:
|
||||
raise ValueError(f"Unsupported tensor parallel style {v}. Supported styles are {ALL_PARALLEL_STYLES}")
|
||||
for name, module in model.named_modules():
|
||||
if not getattr(module, "_is_hooked", False):
|
||||
from transformers.integrations.tensor_parallel import add_tensor_parallel_hooks_to_module
|
||||
|
||||
plan = _get_parameter_tp_plan(parameter_name=name, tp_plan=model._tp_plan, is_weight=False)
|
||||
plan = _get_parameter_tp_plan(parameter_name=name, tp_plan=model_plan, is_weight=False)
|
||||
add_tensor_parallel_hooks_to_module(
|
||||
model=model,
|
||||
module=module,
|
||||
tp_plan=model._tp_plan,
|
||||
tp_plan=model_plan,
|
||||
layer_name="",
|
||||
current_module_plan=plan,
|
||||
device_mesh=device_mesh,
|
||||
|
||||
@@ -2276,7 +2276,87 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
||||
)
|
||||
|
||||
# If current model is a base model, attach `base_model_tp_plan` and `base_model_pp_plan` from config
|
||||
self._pp_plan = self.config.base_model_pp_plan.copy() if self.config.base_model_pp_plan is not None else None
|
||||
self._pp_plan = self.config.base_model_pp_plan.copy() if self.config.base_model_pp_plan is not None else {}
|
||||
self._tp_plan = self.config.base_model_tp_plan.copy() if self.config.base_model_tp_plan is not None else {}
|
||||
self._ep_plan = self.config.base_model_ep_plan.copy() if self.config.base_model_ep_plan is not None else {}
|
||||
for name, module in self.named_children():
|
||||
if plan := getattr(module, "_ep_plan", None):
|
||||
self._ep_plan.update({f"{name}.{k}": v for k, v in plan.copy().items()})
|
||||
if plan := getattr(module, "_tp_plan", None):
|
||||
self._tp_plan.update({f"{name}.{k}": v for k, v in plan.copy().items()})
|
||||
if plan := getattr(module, "_pp_plan", None):
|
||||
self._pp_plan.update({f"{name}.{k}": v for k, v in plan.copy().items()})
|
||||
|
||||
@property
|
||||
def tp_plan(self) -> dict[str, str]:
|
||||
"""
|
||||
The full tp plan for the model's modules
|
||||
"""
|
||||
if hasattr(self.config, "distributed_config") and self.config.distributed_config.enable_expert_parallel:
|
||||
return self._ep_plan
|
||||
return self._tp_plan
|
||||
|
||||
@property
|
||||
def pp_plan(self) -> dict[str, tuple[str, str]]:
|
||||
return self._pp_plan
|
||||
|
||||
@tp_plan.setter
|
||||
def tp_plan(self, plan: dict[str, str]):
|
||||
if plan is not None:
|
||||
# Validate that all parallel styles in the plan are supported
|
||||
from .integrations.tensor_parallel import ALL_PARALLEL_STYLES
|
||||
|
||||
for layer_pattern, parallel_style in plan.items():
|
||||
if parallel_style not in ALL_PARALLEL_STYLES:
|
||||
raise ValueError(
|
||||
f"Unsupported tensor parallel style '{parallel_style}' for layer '{layer_pattern}'. "
|
||||
f"Supported styles are {list(ALL_PARALLEL_STYLES.keys())}"
|
||||
)
|
||||
|
||||
# Validate that the layer patterns match existing model structure
|
||||
# We check this by getting all parameter names and seeing if any match the patterns
|
||||
if hasattr(self, "named_parameters"):
|
||||
model_param_names = [name for name, _ in self.named_parameters()]
|
||||
if model_param_names: # Only validate if model has parameters
|
||||
import re
|
||||
|
||||
for layer_pattern in plan.keys():
|
||||
# Convert pattern to regex (replace * with .*)
|
||||
regex_pattern = layer_pattern.replace("*", r"\d+")
|
||||
pattern_matched = False
|
||||
for param_name in model_param_names:
|
||||
if re.match(regex_pattern, param_name):
|
||||
pattern_matched = True
|
||||
break
|
||||
if not pattern_matched:
|
||||
# Try more flexible matching - check if pattern components exist
|
||||
pattern_parts = layer_pattern.split(".")
|
||||
flexible_matched = False
|
||||
for param_name in model_param_names:
|
||||
param_parts = param_name.split(".")
|
||||
if len(pattern_parts) <= len(param_parts):
|
||||
match_count = 0
|
||||
for i, pattern_part in enumerate(pattern_parts):
|
||||
if pattern_part == "*":
|
||||
match_count += 1
|
||||
elif i < len(param_parts) and pattern_part == param_parts[i]:
|
||||
match_count += 1
|
||||
if match_count == len(pattern_parts):
|
||||
flexible_matched = True
|
||||
break
|
||||
if not flexible_matched:
|
||||
import warnings
|
||||
|
||||
warnings.warn(
|
||||
f"Layer pattern '{layer_pattern}' does not match any parameters in the model. "
|
||||
f"This rule may not be applied during tensor parallelization."
|
||||
)
|
||||
|
||||
self._tp_plan = plan if plan is not None else {}
|
||||
|
||||
@pp_plan.setter
|
||||
def pp_plan(self, plan: dict[str, tuple[str, str]]):
|
||||
self._pp_plan = plan
|
||||
|
||||
def dequantize(self):
|
||||
"""
|
||||
|
||||
@@ -215,6 +215,124 @@ class TestTensorParallel(TestCasePlus):
|
||||
del non_tp_tensor, tp_tensor
|
||||
|
||||
|
||||
class TestTensorParallelProperties(TestCasePlus):
|
||||
def test_tp_plan_property_setter_getter(self):
|
||||
"""Test that tp_plan property can be set and retrieved correctly."""
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
model_id = "JackFram/llama-68m"
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto")
|
||||
|
||||
# Test setting empty plan
|
||||
model.tp_plan = {}
|
||||
self.assertEqual(model.tp_plan, {})
|
||||
|
||||
# Test setting a valid plan
|
||||
valid_plan = {"model.layers.*.self_attn.q_proj": "colwise"}
|
||||
model.tp_plan = valid_plan
|
||||
self.assertEqual(model.tp_plan, valid_plan)
|
||||
|
||||
# Test updating the plan
|
||||
model.tp_plan.update({"model.layers.*.self_attn.k_proj": "colwise"})
|
||||
expected_plan = {"model.layers.*.self_attn.q_proj": "colwise", "model.layers.*.self_attn.k_proj": "colwise"}
|
||||
self.assertEqual(model.tp_plan, expected_plan)
|
||||
|
||||
# Test overriding existing entry
|
||||
model.tp_plan.update({"model.layers.*.self_attn.q_proj": "colwise_rep"})
|
||||
expected_plan = {
|
||||
"model.layers.*.self_attn.q_proj": "colwise_rep",
|
||||
"model.layers.*.self_attn.k_proj": "colwise",
|
||||
}
|
||||
self.assertEqual(model.tp_plan, expected_plan)
|
||||
|
||||
def test_tp_plan_validation_invalid_style(self):
|
||||
"""Test that invalid parallel styles are rejected."""
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
model_id = "JackFram/llama-68m"
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto")
|
||||
|
||||
# Test invalid parallel style
|
||||
with self.assertRaises(ValueError) as context:
|
||||
model.tp_plan = {"layers.*.self_attn.q_proj": "invalid_style"}
|
||||
|
||||
self.assertIn("Unsupported tensor parallel style 'invalid_style'", str(context.exception))
|
||||
self.assertIn("Supported styles are", str(context.exception))
|
||||
|
||||
def test_tp_plan_validation_nonexistent_layer_warning(self):
|
||||
"""Test that warnings are issued for non-existent layer patterns."""
|
||||
import warnings
|
||||
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
model_id = "JackFram/llama-68m"
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto")
|
||||
|
||||
# Test warning for non-existent layer pattern
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
warnings.simplefilter("always")
|
||||
model.tp_plan = {"nonexistent.*.layer": "colwise"}
|
||||
|
||||
# Check that a warning was issued
|
||||
self.assertTrue(len(w) > 0)
|
||||
warning_message = str(w[0].message)
|
||||
self.assertIn("Layer pattern 'nonexistent.*.layer' does not match any parameters", warning_message)
|
||||
|
||||
def test_tp_plan_valid_layer_patterns(self):
|
||||
"""Test that valid layer patterns are accepted without warnings."""
|
||||
import warnings
|
||||
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
model_id = "JackFram/llama-68m"
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto")
|
||||
|
||||
# Test valid layer patterns that should match the model structure
|
||||
valid_plans = [
|
||||
{"model.layers.*.self_attn.q_proj": "colwise"},
|
||||
{"model.layers.*.self_attn.k_proj": "rowwise"},
|
||||
{"model.layers.*.mlp.gate_proj": "colwise_rep"},
|
||||
]
|
||||
|
||||
for plan in valid_plans:
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
warnings.simplefilter("always")
|
||||
model.tp_plan = plan
|
||||
|
||||
# Filter out any warnings that are not about layer patterns
|
||||
layer_warnings = [
|
||||
warning
|
||||
for warning in w
|
||||
if "Layer pattern" in str(warning.message)
|
||||
and "does not match any parameters" in str(warning.message)
|
||||
]
|
||||
|
||||
# Should not have layer pattern warnings for valid patterns
|
||||
self.assertEqual(
|
||||
len(layer_warnings),
|
||||
0,
|
||||
f"Unexpected warning for valid pattern {plan}: {[str(w.message) for w in layer_warnings]}",
|
||||
)
|
||||
|
||||
# Verify the final plan was set correctly
|
||||
self.assertEqual(model.tp_plan, valid_plans[-1])
|
||||
|
||||
def test_tp_plan_none_handling(self):
|
||||
"""Test that None values are handled correctly."""
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
model_id = "JackFram/llama-68m"
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto")
|
||||
|
||||
# Test setting None
|
||||
model.tp_plan = None
|
||||
self.assertEqual(model.tp_plan, {})
|
||||
|
||||
# Test setting a plan after None
|
||||
model.tp_plan = {"model.layers.*.self_attn.q_proj": "colwise"}
|
||||
self.assertEqual(model.tp_plan, {"model.layers.*.self_attn.q_proj": "colwise"})
|
||||
|
||||
|
||||
@require_torch_multi_accelerator
|
||||
class TestTensorParallelAccelerator(TestTensorParallel):
|
||||
nproc_per_node = backend_device_count(torch_device)
|
||||
|
||||
Reference in New Issue
Block a user