tp plan should not be NONE (#38255)

* accept custom device_mesh

* fix device_map

* assert that num_heads % tp_size == 0

* todo.

* ReplicateParallel

* handle tied weights

* handle dtensor in save_pretrained with safe_serialization

* tp test works

* doesnt work

* fix shard_and_distribute_module's rank should be local_rank

* tp=4 is correct

* dp+tp is broken

* todo allreduce with dtensors on another dim is annoying

* workaround to sync dp grads when using dtensors

* loading a checkpoint works

* wandb and compare losses with different tp/dp

* cleaning

* cleaning

* .

* .

* logs

* CP2 DP2 no mask works after commenting attn_mask and is_causal from scaled_dot_product_attention

* DP=2 TP=2 now works even with tied embeddings

* model.parameters() and model.module.parameters() are empty..

* reformat sanity_check_tensor_sync

* set atol=1e-4 for CP to pass

* try populate _parameters from named_modules

* refactors
TP2 DP2 works
CP2 DP2 works

* is_causal=True and pack sequences, no attn mask, and preshuffle dataset

* fix packing

* CP=4 doesn't work

* fix labels and position_ids for CP

* DP CP works with transformers 🥳🥳🥳

* refactor

* add example cp

* fixup

* revert sdpa changes

* example cleared

* add CP, DP to the mesh init

* nit

* clean

* use `ALL_PARALLEL_STYLES`

* style

* FSDP works

* log on 1 rank

* .

* fix?

* FSDP1 also has .parameters() bug

* reported gradnorm when using FSDP1 is wrong, but loss is correct so it's okay

* .

* style and fixup

* move stuff around

* fix tests

* style

* let's make it a check

* add missing licences

* warning should be an info

* tp plan should not be NONE

* test all

* god damn it

* test all

---------

Co-authored-by: nouamanetazi <nouamane98@gmail.com>
This commit is contained in:
Arthur
2025-05-21 10:22:38 +02:00
committed by GitHub
parent 711d78d104
commit e288ee00d8
6 changed files with 1539 additions and 133 deletions

View File

@@ -62,8 +62,9 @@ from .integrations.flash_attention import flash_attention_forward
from .integrations.flex_attention import flex_attention_forward
from .integrations.sdpa_attention import sdpa_attention_forward
from .integrations.tensor_parallel import (
SUPPORTED_TP_STYLES,
ALL_PARALLEL_STYLES,
_get_parameter_tp_plan,
initialize_tensor_parallelism,
repack_weights,
replace_state_dict_local_with_dtensor,
shard_and_distribute_module,
@@ -797,7 +798,7 @@ def _load_state_dict_into_meta_model(
param_name,
casting_dtype,
to_contiguous,
int(os.environ["RANK"]), # the rank
device_mesh.get_local_rank(),
device_mesh,
)
else:
@@ -1964,9 +1965,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
if self._tp_plan is not None and is_torch_greater_or_equal("2.3"):
for _, v in self._tp_plan.items():
if v not in SUPPORTED_TP_STYLES:
if v not in ALL_PARALLEL_STYLES:
raise ValueError(
f"Unsupported tensor parallel style {v}. Supported styles are {SUPPORTED_TP_STYLES}"
f"Unsupported tensor parallel style {v}. Supported styles are {ALL_PARALLEL_STYLES}"
)
def dequantize(self):
@@ -3559,6 +3560,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
state_dict = replace_state_dict_local_with_dtensor(state_dict, self._tp_plan, self._device_mesh)
if safe_serialization:
# TODO: fix safe_serialization for tied weights
# Safetensors does not allow tensor aliasing.
# We're going to remove aliases before saving
ptrs = collections.defaultdict(list)
@@ -4040,6 +4042,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
`torchrun [args] script.py`. This will be much faster than using a `device_map`, but has limitations.
tp_size (`str`, *optional*):
A torch tensor parallel degree. If not provided would default to world size.
device_mesh (`torch.distributed.DeviceMesh`, *optional*):
A torch device mesh. If not provided would default to world size. Used only for tensor parallel for now.
offload_folder (`str` or `os.PathLike`, *optional*):
If the `device_map` contains any value `"disk"`, the folder where we will offload weights.
offload_state_dict (`bool`, *optional*):
@@ -4137,6 +4141,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
gguf_file = kwargs.pop("gguf_file", None)
tp_plan = kwargs.pop("tp_plan", None)
tp_size = kwargs.pop("tp_size", None)
device_mesh = kwargs.pop("device_mesh", None)
trust_remote_code = kwargs.pop("trust_remote_code", None)
# Load models with hardcoded key mapping on class for VLMs only, to keep BC and standardize model
@@ -4172,59 +4177,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
# We need to correctly dispatch the model on the current process device. The easiest way for this is to use a simple
# `device_map` pointing to the correct device
device_mesh = None
if tp_plan is not None:
if not is_torch_greater_or_equal("2.5"):
raise EnvironmentError("tensor parallel is only supported for `torch>=2.5`.")
# Detect the accelerator on the machine. If no accelerator is available, it returns CPU.
device_type = torch._C._get_accelerator().type
if not torch.distributed.is_initialized():
try:
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
if device_type == "cuda":
torch.distributed.init_process_group(
"nccl", rank=rank, world_size=world_size, init_method="env://"
)
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
elif device_type == "cpu":
cpu_backend = "ccl" if int(os.environ.get("CCL_WORKER_COUNT", 0)) else "gloo"
torch.distributed.init_process_group(cpu_backend, rank=rank, world_size=world_size)
elif device_type == "xpu":
torch.distributed.init_process_group("ccl", rank=rank, world_size=world_size)
torch.xpu.set_device(int(os.environ["LOCAL_RANK"]))
elif device_type == "hpu":
torch.distributed.init_process_group("hccl", rank=rank, world_size=world_size)
torch.hpu.set_device(int(os.environ["LOCAL_RANK"]))
except Exception as e:
raise EnvironmentError(
"We tried to initialize torch.distributed for you, but it failed, make"
"sure you init torch distributed in your script to use `tp_plan='auto'`"
) from e
# Get device with index assuming equal number of devices per host
if device_type == "xpu":
index = torch.xpu.current_device()
elif device_type == "hpu":
index = torch.hpu.current_device()
if device_mesh is None and tp_plan is not None:
tp_plan, device_map, device_mesh = initialize_tensor_parallelism(tp_plan, tp_size=None)
else:
index = None if device_type == "cpu" else torch.cuda.current_device()
tp_device = torch.device(device_type, index)
if index is not None and index > 0:
import sys
sys.stdout = open(os.devnull, "w")
sys.stderr = open(os.devnull, "w")
# This is the easiest way to dispatch to the current process device
device_map = tp_device
# Assuming sharding the model onto the world when tp_size not provided
tp_size = tp_size if tp_size is not None else torch.distributed.get_world_size()
device_mesh = torch.distributed.init_device_mesh(tp_device.type, (tp_size,))
# TODO: make device_mesh support multiple dimensions
if device_mesh.ndim == 1:
raise ValueError("device_mesh must be 1 dimensional and will be used for TP")
device_map = torch.device(device_mesh.device_type, int(os.environ["LOCAL_RANK"]))
if use_auth_token is not None:
warnings.warn(
@@ -5142,7 +5102,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
name,
casting_dtype,
to_contiguous,
os.environ["RANK"],
device_mesh.get_local_rank(),
device_mesh,
)