Add ep (#39501)
* EP + updates Co-authored-by: Nouamane Tazi <NouamaneTazi@users.noreply.github.com> Co-authored-by: drbh <drbh@users.noreply.github.com> * remove unrelated change * not working yet but let's see where it goes! * update the api a bit * udpate * where I am at for now * fix ep * refactor the API * yups * fix * fixup * clean modeling * just support llama4 for now! * properly avoid * fix * nits * Update src/transformers/models/llama4/modeling_llama4.py * Update src/transformers/integrations/tensor_parallel.py * style * ,,,, * update --------- Co-authored-by: Nouamane Tazi <NouamaneTazi@users.noreply.github.com> Co-authored-by: drbh <drbh@users.noreply.github.com>
This commit is contained in:
@@ -63,8 +63,8 @@ from .integrations.flex_attention import flex_attention_forward
|
||||
from .integrations.sdpa_attention import sdpa_attention_forward
|
||||
from .integrations.sdpa_paged import sdpa_attention_paged_forward
|
||||
from .integrations.tensor_parallel import (
|
||||
ALL_PARALLEL_STYLES,
|
||||
_get_parameter_tp_plan,
|
||||
distribute_model,
|
||||
initialize_tensor_parallelism,
|
||||
repack_weights,
|
||||
replace_state_dict_local_with_dtensor,
|
||||
@@ -2218,6 +2218,9 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
||||
"""
|
||||
A method executed at the end of each Transformer model initialization, to execute code that needs the model's
|
||||
modules properly initialized (such as weight initialization).
|
||||
|
||||
This is also used when the user is running distributed code. We add hooks to the modules here, according to
|
||||
the model's tp_plan!
|
||||
"""
|
||||
self.init_weights()
|
||||
self._backward_compatibility_gradient_checkpointing()
|
||||
@@ -2250,17 +2253,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
||||
|
||||
# If current model is a base model, attach `base_model_tp_plan` and `base_model_pp_plan` from config
|
||||
self._pp_plan = self.config.base_model_pp_plan.copy() if self.config.base_model_pp_plan is not None else None
|
||||
self._tp_plan = self.config.base_model_tp_plan.copy() if self.config.base_model_tp_plan is not None else {}
|
||||
for name, module in self.named_children():
|
||||
if plan := getattr(module, "_tp_plan", None):
|
||||
self._tp_plan.update({f"{name}.{k}": v for k, v in plan.copy().items()})
|
||||
|
||||
if self._tp_plan is not None and is_torch_greater_or_equal("2.5") and _torch_distributed_available:
|
||||
for v in self._tp_plan.values():
|
||||
if v not in ALL_PARALLEL_STYLES:
|
||||
raise ValueError(
|
||||
f"Unsupported tensor parallel style {v}. Supported styles are {ALL_PARALLEL_STYLES}"
|
||||
)
|
||||
|
||||
def dequantize(self):
|
||||
"""
|
||||
@@ -4568,6 +4560,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
||||
load_in_8bit = kwargs.pop("load_in_8bit", False)
|
||||
load_in_4bit = kwargs.pop("load_in_4bit", False)
|
||||
quantization_config = kwargs.pop("quantization_config", None)
|
||||
distributed_config = kwargs.pop("distributed_config", None)
|
||||
subfolder = kwargs.pop("subfolder", "")
|
||||
commit_hash = kwargs.pop("_commit_hash", None)
|
||||
variant = kwargs.pop("variant", None)
|
||||
@@ -4588,6 +4581,9 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
||||
):
|
||||
key_mapping = cls._checkpoint_conversion_mapping
|
||||
|
||||
if distributed_config is not None:
|
||||
tp_plan = "auto"
|
||||
|
||||
# Not used anymore -- remove them from the kwargs
|
||||
_ = kwargs.pop("resume_download", None)
|
||||
_ = kwargs.pop("mirror", None)
|
||||
@@ -4619,16 +4615,12 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
||||
# `device_map` pointing to the correct device
|
||||
if tp_plan is not None:
|
||||
if device_mesh is None:
|
||||
tp_plan, device_map, device_mesh = initialize_tensor_parallelism(tp_plan, tp_size=None)
|
||||
tp_plan, device_map, device_mesh, tp_size = initialize_tensor_parallelism(tp_plan, tp_size=tp_size)
|
||||
else:
|
||||
if "tp" not in device_mesh.mesh_dim_names:
|
||||
raise ValueError(
|
||||
"When using `tp_plan`, the `device_mesh` must contain a 'tp' dimension. "
|
||||
"Please provide a valid `device_mesh`."
|
||||
)
|
||||
device_mesh = device_mesh["tp"]
|
||||
tp_size = device_mesh["tp"].size()
|
||||
device_map = torch.device(f"{device_mesh.device_type}:{int(os.environ['LOCAL_RANK'])}")
|
||||
# TODO: make device_mesh support multiple dimensions
|
||||
if device_mesh.ndim > 1:
|
||||
raise ValueError("device_mesh must be 1 dimensional and will be used for TP")
|
||||
device_map = torch.device(device_mesh.device_type, int(os.environ["LOCAL_RANK"]))
|
||||
|
||||
if tp_size is None:
|
||||
tp_size = torch.distributed.get_world_size()
|
||||
@@ -4928,23 +4920,18 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
||||
)
|
||||
|
||||
config.name_or_path = pretrained_model_name_or_path
|
||||
|
||||
# Instantiate model.
|
||||
model_init_context = cls.get_init_context(is_quantized, _is_ds_init_called)
|
||||
|
||||
config = copy.deepcopy(config) # We do not want to modify the config inplace in from_pretrained.
|
||||
with ContextManagers(model_init_context):
|
||||
# Let's make sure we don't run the init function of buffer modules
|
||||
model = cls(config, *model_args, **model_kwargs)
|
||||
|
||||
if _torch_distributed_available and device_mesh is not None:
|
||||
model = distribute_model(model, distributed_config, device_mesh, tp_size)
|
||||
|
||||
# Make sure to tie the weights correctly
|
||||
model.tie_weights()
|
||||
|
||||
# Last check for tp
|
||||
if device_mesh is not None and not model.supports_tp_plan:
|
||||
if config.base_model_tp_plan is None and config.get_text_config().base_model_tp_plan is None:
|
||||
raise NotImplementedError("This model does not have a tensor parallel plan.")
|
||||
|
||||
# make sure we use the model's config since the __init__ call might have copied it
|
||||
config = model.config
|
||||
|
||||
@@ -5025,11 +5012,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
||||
key_mapping=key_mapping,
|
||||
weights_only=weights_only,
|
||||
)
|
||||
|
||||
# record tp degree the model sharded to
|
||||
model._tp_size = tp_size
|
||||
model._device_mesh = device_mesh
|
||||
|
||||
# make sure token embedding weights are still tied if needed
|
||||
model.tie_weights()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user