From 711d78d104366c16dd99ef11084c27848f7a4133 Mon Sep 17 00:00:00 2001 From: Lysandre Debut Date: Tue, 20 May 2025 22:43:04 +0200 Subject: [PATCH] Revert parallelism temporarily (#38240) * Revert "Protect ParallelInterface" This reverts commit cb513e35f9c096d60558bd43110837cbb66611ce. * Revert "parallelism goes brrr (#37877)" This reverts commit 1c2f36b480e02c9027d2523746d34e27b39e01a4. * Empty commit --- examples/3D_parallel.py | 422 ---------- examples/pytorch/3d_parallel_checks.py | 780 ------------------ examples/pytorch/context_parallel.py | 94 --- setup.py | 4 +- src/transformers/dependency_versions_table.py | 2 +- src/transformers/integrations/__init__.py | 4 +- .../integrations/tensor_parallel.py | 279 ++----- src/transformers/modeling_utils.py | 75 +- 8 files changed, 138 insertions(+), 1522 deletions(-) delete mode 100644 examples/3D_parallel.py delete mode 100644 examples/pytorch/3d_parallel_checks.py delete mode 100644 examples/pytorch/context_parallel.py diff --git a/examples/3D_parallel.py b/examples/3D_parallel.py deleted file mode 100644 index d56e63bc68..0000000000 --- a/examples/3D_parallel.py +++ /dev/null @@ -1,422 +0,0 @@ -""": -This script is used to test training a model using Tensor Parallelism and Data Parallelism. - -Usage: -export CUDA_VISIBLE_DEVICES=0,1,2,3 -export CUDA_VISIBLE_DEVICES=4,5,6,7 -export CUDA_VISIBLE_DEVICES=5,6,7 -TP_SIZE=2 DP_SIZE=2 torchrun --nproc_per_node=4 --rdzv_endpoint=localhost:29503 examples/3D_parallel.py -CP_SIZE=2 DP_SIZE=2 torchrun --nproc_per_node=4 examples/3D_parallel.py -CP_SIZE=2 TP_SIZE=2 torchrun --nproc_per_node=4 examples/3D_parallel.py -DP_SIZE=2 CP_SIZE=2 TP_SIZE=2 torchrun --nproc_per_node=8 examples/3D_parallel.py - -TP_SIZE=1 CP_SIZE=4 torchrun --nproc_per_node=4 examples/3D_parallel.py -TP_SIZE=1 DP_SIZE=4 torchrun --nproc_per_node=4 examples/3D_parallel.py -TP_SIZE=4 DP_SIZE=1 torchrun --nproc_per_node=4 --rdzv_endpoint=localhost:29503 examples/3D_parallel.py -IGNORE_SANITY=1 CP_SIZE=1 TP_SIZE=1 DP_SIZE=1 torchrun --nproc_per_node=1 --rdzv_endpoint=localhost:29504 examples/3D_parallel.py -ocalhost:29504 test_train.py -""" - -import logging -import os -from contextlib import nullcontext -from typing import Iterable - -import torch -import torch.distributed as dist -import torch.distributed.checkpoint as dcp -import torch.optim as optim -import wandb -from datasets import load_dataset -from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict -from torch.distributed.checkpoint.stateful import Stateful -from torch.distributed.device_mesh import DeviceMesh -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP -from torch.distributed.fsdp import ShardingStrategy -from torch.distributed.tensor import DTensor -from torch.distributed.tensor.experimental import context_parallel -from torch.nn.attention import SDPBackend, sdpa_kernel -from torch.utils.data import DataLoader -from torch.utils.data.distributed import DistributedSampler - -from transformers import AutoModelForCausalLM, AutoTokenizer - - -# torch.use_deterministic_algorithms(True) -torch.backends.cudnn.deterministic = True - -# Set up logging -logging.basicConfig( - format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", - datefmt="%m/%d/%Y %H:%M:%S", - level=logging.INFO, -) -logger = logging.getLogger(__name__) - -# from torch.distributed.tensor.experimental._attention import set_rotate_method - -# set_rotate_method("alltoall") # CP rotate shards using all-to-all - - -def main(): - tp_size = int(os.environ.get("TP_SIZE", 1)) - dp_size = int(os.environ.get("DP_SIZE", 1)) - cp_size = int(os.environ.get("CP_SIZE", 1)) # Add CP size configuration - sdpa_backend = SDPBackend.FLASH_ATTENTION # For CP - # sdpa_backend = SDPBackend.MATH # For CP - global_batch_size = 8 # Desired global batch size - seq_len = 1024 # Sequence length - num_train_steps = 10000 # Number of training steps - LR = 1e-5 - model_name = "HuggingFaceTB/SmolLM2-1.7B" - # model_name = "unsloth/Llama-3.2-1B" - - CHECKPOINT_DIR = f"checkpoint_tp{tp_size}_dp{dp_size}_cp{cp_size}" - - # Initialize distributed environment - if "RANK" in os.environ and "WORLD_SIZE" in os.environ: - dist.init_process_group("nccl") - rank = dist.get_rank() - world_size = dist.get_world_size() - local_rank = int(os.environ["LOCAL_RANK"]) - torch.cuda.set_device(local_rank) - - assert world_size == tp_size * dp_size * cp_size, ( - f"World size ({world_size}) must equal TP size ({tp_size}) * DP size ({dp_size}) * CP size ({cp_size})" - ) - - mesh = torch.arange(world_size).reshape(dp_size, tp_size, cp_size) - world_mesh = DeviceMesh(device_type="cuda", mesh=mesh, mesh_dim_names=("dp", "tp", "cp")) - tp_mesh = world_mesh["tp"] - dp_mesh = world_mesh["dp"] - cp_mesh = world_mesh["cp"] - world_mesh["dp", "cp"]._flatten(mesh_dim_name="dp_cp") - logger.info(f"Created DeviceMesh: {world_mesh}") - logger.info( - f"Distributed setup - Rank: {rank}, World size: {world_size}, Local rank: {local_rank}, DP: {dp_mesh.get_local_rank()}, TP: {tp_mesh.get_local_rank()}, CP: {cp_mesh.get_local_rank()}" - ) - - if dist.get_rank() == 0: - wandb.init( - project="tp_dp_test", - config={ - "tp_size": tp_size, - "dp_size": dp_size, - "cp_size": cp_size, - "global_batch_size": global_batch_size, - "model_name": model_name, - "dataset": "roneneldan/TinyStories-1M", - "seq_len": seq_len, - "lr": LR, - "weight_decay": 0.1, - }, - name=f"llama_tp{tp_size}_dp{dp_size}_cp{cp_size}" - if model_name == "unsloth/Llama-3.2-1B" - else f"tp{tp_size}_dp{dp_size}_cp{cp_size}", - ) - logger.info("Wandb initialized.") - # Log the current file to wandb - wandb.save("test_train.py") - - # Load model and tokenizer - logger.info(f"Loading model and tokenizer from {model_name}") - tokenizer = AutoTokenizer.from_pretrained(model_name) - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token - logger.info(f"Set pad_token to eos_token: {tokenizer.pad_token}") - - model = AutoModelForCausalLM.from_pretrained( - model_name, - device_mesh=tp_mesh if dist.is_initialized() else None, - tp_plan="auto", - torch_dtype=torch.bfloat16, - ) - logger.info(f"Model loaded onto device mesh: {tp_mesh}") - device = torch.device(f"cuda:{local_rank}") - logger.info(f"Using device: {device} for non-model tensors") - use_ddp = False - if dist.is_initialized() and dp_mesh.size() > 1: - model = FSDP(model, device_mesh=dp_mesh, sharding_strategy=ShardingStrategy.NO_SHARD) - use_ddp = True - pass - - model.train() - - logger.info("Loading TinyStories dataset...") - raw_dataset = load_dataset("roneneldan/TinyStories", split="train[:1%]") # Use 1% for faster testing - - def tokenize_function(examples): - # Tokenize the text without padding - tokenized_batch = tokenizer( - examples["text"], padding=False, truncation=True, max_length=seq_len, return_tensors=None - ) - # Set labels to be the same as input_ids for Causal LM - tokenized_batch["labels"] = tokenized_batch["input_ids"].copy() - return tokenized_batch - - tokenized_dataset = raw_dataset.map(tokenize_function, batched=True, remove_columns=["text"]) - logger.info(f"Dataset loaded and tokenized. Size: {len(tokenized_dataset)}") - - # Create packed sequences - def create_packed_sequences(examples): - # Flatten all sequences - all_tokens = [] - for input_ids in examples["input_ids"]: - all_tokens.extend(input_ids) - - # Split into sequences of seq_len + 1 (for input + label) - num_sequences = len(all_tokens) // (seq_len + 1) - packed_input_ids = [] - packed_labels = [] - - for i in range(num_sequences): - start_idx = i * (seq_len + 1) - end_idx = start_idx + (seq_len + 1) - # Get the full sequence - full_sequence = all_tokens[start_idx:end_idx] - # For input_ids, remove the last token - packed_input_ids.append(full_sequence[:-1]) - # For labels, remove the first token - packed_labels.append(full_sequence[1:]) - - return {"input_ids": packed_input_ids, "labels": packed_labels} - - # Apply packing to the dataset - packed_dataset = tokenized_dataset.map( - create_packed_sequences, - batched=True, - remove_columns=tokenized_dataset.column_names, - batch_size=1000, # Process in batches for efficiency - num_proc=60, - ) - logger.info(f"Dataset packed. New size: {len(packed_dataset)}") - - # Shuffle the packed dataset - packed_dataset = packed_dataset.shuffle(seed=42) - logger.info("Packed dataset shuffled") - - # Calculate local batch size - if dist.is_initialized(): - assert global_batch_size % dp_mesh.size() == 0, ( - f"Global batch size ({global_batch_size}) must be divisible by DP size ({dp_mesh.size()})" - ) - local_batch_size = global_batch_size // dp_mesh.size() - else: - local_batch_size = global_batch_size - - logger.info( - f"Global batch size: {global_batch_size}, DP size: {dp_size if dist.is_initialized() else 1}, Local batch size: {local_batch_size}" - ) - - # Simple collate function since sequences are already packed - def collate_fn(batch): - input_ids = torch.tensor([item["input_ids"] for item in batch], dtype=torch.long) - labels = torch.tensor([item["labels"] for item in batch], dtype=torch.long) - return {"input_ids": input_ids, "labels": labels} - - if dist.is_initialized(): - sampler = DistributedSampler( - packed_dataset, num_replicas=dp_mesh.size(), rank=dp_mesh.get_local_rank(), shuffle=False - ) - else: - sampler = None - - dataloader = DataLoader( - packed_dataset, - batch_size=local_batch_size, - sampler=sampler, - shuffle=False, - collate_fn=collate_fn, - pin_memory=True, - ) - logger.info(f"DataLoader created. Distributed: {dist.is_initialized()}") - - optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=0.1) - - # Training loop - logger.info(f"Starting training for {num_train_steps} steps...") - model.train() - step = 0 - while step < num_train_steps: - for batch in dataloader: - if step >= num_train_steps: - break # Exit loop if max steps reached - - # Move batch to appropriate device - batch = {k: v.to(device) for k, v in batch.items()} - optimizer.zero_grad() - - # Add position_ids to batch before CP sharding - batch_size = batch["input_ids"].shape[0] - position_ids = torch.arange(0, seq_len, dtype=torch.long, device=device) - position_ids = position_ids.unsqueeze(0).expand(batch_size, -1) - batch["position_ids"] = position_ids - from torch.distributed.tensor.experimental._attention import _cp_options - - _cp_options.enable_load_balance = False - - with sdpa_kernel(sdpa_backend): # TODO: ideally move this to attention implementation - cp_context = ( - nullcontext() - if cp_mesh.size() == 1 - else context_parallel( - cp_mesh, - buffers=[ - batch["input_ids"], - batch["labels"], - batch["position_ids"], - ], - buffer_seq_dims=[1, 1, 1], - ) - ) - with cp_context: - # Pop labels from batch before model forward pass - labels = batch.pop("labels") - outputs = model(**batch) # [mbs, seq_len/cp] - loss = outputs.loss - logits = outputs.logits - - # Compute loss with shifted labels - loss = model.loss_function( - logits=logits, labels=None, shift_labels=labels, vocab_size=model.config.vocab_size - ) - loss.backward() - - # all reduce grads across dp_cp if applicable - all_reduce_grads(model, world_mesh, use_ddp=use_ddp) - - if hasattr(model, "clip_grad_norm_"): - gradnorm = model.clip_grad_norm_(max_norm=1.0, norm_type=2.0) # TODO: fix reported gradnorm - else: - # only works with FSDP's NO_SHARD otherwise we should use FSDP's clip_grad_norm_ - assert len(list(model.parameters())) > 5, "No parameters found in model. Probably DDP bug.." - gradnorm = clip_grad_norm_(model.parameters(), max_norm=1.0, norm_type=2.0, foreach=True) - - optimizer.step() - # allreduce loss across cp_dp before logging - if dist.is_initialized() and (cp_mesh.size() > 1 or dp_mesh.size() > 1): - dist.all_reduce(loss, group=world_mesh["dp_cp"].get_group(), op=dist.ReduceOp.AVG) - current_loss = loss.item() - - # Log loss and gradnorm to wandb (only on rank 0 of dp group) - if not dist.is_initialized() or dist.get_rank() == 0: - logger.info( - f"Step: {step} | GBS: {global_batch_size} | DP: {dp_mesh.size()} | TP: {tp_mesh.size()} | CP: {cp_mesh.size()} | Loss: {current_loss} | Gradnorm: {gradnorm} | lr: {LR}" - ) - wandb.log( - { - "train/loss": current_loss, - "train/gradnorm": gradnorm, - "step": step, - "lr": LR, - "GBS": global_batch_size, - } - ) - - step += 1 # Increment step count - - logger.info("Training loop finished.") - - # Save model using DCP (only if distributed) - if dist.is_initialized(): - state_dict = {"app": AppState(model, optimizer)} - dcp.save( - state_dict=state_dict, - checkpoint_id=CHECKPOINT_DIR, - ) - logger.info(f"Saved checkpoint to {CHECKPOINT_DIR}") - else: - # Fallback to regular save for non-distributed case - save_dir = "test_model_nondist" - model.save_pretrained(save_dir, safe_serialization=False) - tokenizer.save_pretrained(save_dir) # Save tokenizer too - logger.info(f"Saved model to {save_dir}") - - dist.destroy_process_group() - logger.info("Cleaned up distributed process group") - # Finish wandb run on rank 0 - if dist.get_rank() == 0: - wandb.finish() - logger.info("Wandb run finished.") - - -def all_reduce_grads(model, world_mesh, use_ddp): - """All reduce gradients across dp_cp if applicable.""" - cp_mesh = world_mesh["cp"] - if use_ddp: - # DDP/FSDP takes care of syncing grads - mesh = cp_mesh - else: - mesh = world_mesh["dp", "cp"]._flatten(mesh_dim_name="dp_cp") - if dist.is_initialized() and mesh.size() > 1: - for name, param in model.named_parameters(): - if param.grad is not None: - # Workaround for cross-mesh communication limitation with DTensor gradients - if isinstance(param.grad, DTensor): - local_grad = param.grad.to_local() - # Ensure grad requires grad for inplace modification checks (might not be needed) - # local_grad = local_grad.detach().requires_grad_(True) - torch.distributed.all_reduce(local_grad, op=torch.distributed.ReduceOp.SUM, group=mesh.get_group()) - local_grad = local_grad / mesh.size() - # Assign averaged grad back - need careful handling if DTensor structure is complex - # This simple assignment might work if the grad structure matches param structure - param.grad = DTensor.from_local( - local_grad, device_mesh=param.grad.device_mesh, placements=param.grad.placements - ) - else: - # Handle regular tensors if any exist (e.g. buffers not converted to DTensor) - torch.distributed.all_reduce(param.grad, op=torch.distributed.ReduceOp.AVG, group=mesh.get_group()) - - -class AppState(Stateful): - """Wrapper for checkpointing the Application State including model and optimizer.""" - - def __init__(self, model, optimizer=None): - self.model = model - self.optimizer = optimizer - - def state_dict(self): - model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer) - return {"model": model_state_dict, "optim": optimizer_state_dict} - - def load_state_dict(self, state_dict): - set_state_dict( - self.model, self.optimizer, model_state_dict=state_dict["model"], optim_state_dict=state_dict["optim"] - ) - - -def clip_grad_norm_( - parameters: Iterable[torch.Tensor], - max_norm: float, - norm_type: float = 2.0, - error_if_nonfinite: bool = False, - foreach: bool | None = None, -) -> torch.Tensor: - """ - Clip the gradient norm of an iterable of parameters. - """ - # Filter out parameters with no gradients - parameters = [p for p in parameters if p.grad is not None] - assert len(parameters) > 0, "No parameters with gradients found" - - # Calculate total norm - if norm_type == float("inf"): - total_norm = max(p.grad.detach().abs().max() for p in parameters) - else: - total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type) for p in parameters]), norm_type) - - # Convert DTensor to local tensor if needed - if isinstance(total_norm, DTensor): - total_norm = total_norm.full_tensor() - - # Clip gradients - clip_coef = max_norm / (total_norm + 1e-6) - if clip_coef < 1: - for p in parameters: - p.grad.detach().mul_(clip_coef) - - return total_norm - - -if __name__ == "__main__": - main() diff --git a/examples/pytorch/3d_parallel_checks.py b/examples/pytorch/3d_parallel_checks.py deleted file mode 100644 index 1c7e88e5e4..0000000000 --- a/examples/pytorch/3d_parallel_checks.py +++ /dev/null @@ -1,780 +0,0 @@ -""": -This script is used to test training a model using Tensor Parallelism and Data Parallelism. - -Usage: -export CUDA_VISIBLE_DEVICES=0,1,2,3 -export CUDA_VISIBLE_DEVICES=4,5,6,7 -export CUDA_VISIBLE_DEVICES=5,6,7 -TP_SIZE=2 DP_SIZE=2 torchrun --nproc_per_node=4 --rdzv_endpoint=localhost:29503 test_train.py -CP_SIZE=2 DP_SIZE=2 torchrun --nproc_per_node=4 test_train.py -CP_SIZE=2 TP_SIZE=2 torchrun --nproc_per_node=4 test_train.py - -TP_SIZE=1 CP_SIZE=4 torchrun --nproc_per_node=4 test_train.py -TP_SIZE=1 DP_SIZE=4 torchrun --nproc_per_node=4 test_train.py -TP_SIZE=4 DP_SIZE=1 torchrun --nproc_per_node=4 --rdzv_endpoint=localhost:29503 test_train.py -IGNORE_SANITY=1 CP_SIZE=1 TP_SIZE=1 DP_SIZE=1 torchrun --nproc_per_node=1 --rdzv_endpoint=l -ocalhost:29504 test_train.py -""" - -import logging -import os -from contextlib import nullcontext -from typing import Dict, Iterable, Optional - -import torch -import torch.distributed as dist -import torch.distributed.checkpoint as dcp -import torch.nn as nn -import torch.optim as optim -import wandb -from datasets import load_dataset -from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict -from torch.distributed.checkpoint.stateful import Stateful -from torch.distributed.device_mesh import DeviceMesh -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP -from torch.distributed.fsdp import ShardingStrategy -from torch.distributed.tensor import DTensor -from torch.distributed.tensor.experimental import context_parallel -from torch.nn.attention import SDPBackend, sdpa_kernel -from torch.utils.data import DataLoader, default_collate -from torch.utils.data.distributed import DistributedSampler - -from transformers import AutoModelForCausalLM, AutoTokenizer - - -ignore_sanity_checks = int(os.environ.get("IGNORE_SANITY", 0)) == 1 -# torch.use_deterministic_algorithms(True) -torch.backends.cudnn.deterministic = True - -# Set up logging -logging.basicConfig( - format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", - datefmt="%m/%d/%Y %H:%M:%S", - level=logging.INFO, -) -logger = logging.getLogger(__name__) - -# from torch.distributed.tensor.experimental._attention import set_rotate_method - -# set_rotate_method("alltoall") # rotate shards using all-to-all - - -def main(): - tp_size = int(os.environ.get("TP_SIZE", 1)) - dp_size = int(os.environ.get("DP_SIZE", 4)) - cp_size = int(os.environ.get("CP_SIZE", 1)) # Add CP size configuration - sdpa_backend = SDPBackend.FLASH_ATTENTION # For CP - # sdpa_backend = SDPBackend.MATH # For CP - global_batch_size = 8 # Desired global batch size - seq_len = 1024 # Sequence length - num_train_steps = 10000 # Number of training steps - LR = 1e-5 - model_name = "HuggingFaceTB/SmolLM2-1.7B" - # model_name = "unsloth/Llama-3.2-1B" - - CHECKPOINT_DIR = f"checkpoint_tp{tp_size}_dp{dp_size}_cp{cp_size}" - - # Initialize distributed environment - if "RANK" in os.environ and "WORLD_SIZE" in os.environ: - dist.init_process_group("nccl") - rank = dist.get_rank() - world_size = dist.get_world_size() - local_rank = int(os.environ["LOCAL_RANK"]) - torch.cuda.set_device(local_rank) - - assert world_size == tp_size * dp_size * cp_size, ( - f"World size ({world_size}) must equal TP size ({tp_size}) * DP size ({dp_size}) * CP size ({cp_size})" - ) - - mesh = torch.arange(world_size).reshape(dp_size, tp_size, cp_size) - world_mesh = DeviceMesh(device_type="cuda", mesh=mesh, mesh_dim_names=("dp", "tp", "cp")) - tp_mesh = world_mesh["tp"] - dp_mesh = world_mesh["dp"] - cp_mesh = world_mesh["cp"] - world_mesh["dp", "cp"]._flatten(mesh_dim_name="dp_cp") - logger.info(f"Created DeviceMesh: {world_mesh}") - logger.info( - f"Distributed setup - Rank: {rank}, World size: {world_size}, Local rank: {local_rank}, DP: {dp_mesh.get_local_rank()}, TP: {tp_mesh.get_local_rank()}, CP: {cp_mesh.get_local_rank()}" - ) - - if dist.get_rank() == 0: - wandb.init( - project="tp_dp_test", - config={ - "tp_size": tp_size, - "dp_size": dp_size, - "cp_size": cp_size, - "global_batch_size": global_batch_size, - "model_name": model_name, - "dataset": "roneneldan/TinyStories-1M", - "seq_len": seq_len, - "lr": LR, - "weight_decay": 0.1, - }, - name=f"llama_tp{tp_size}_dp{dp_size}_cp{cp_size}" - if model_name == "unsloth/Llama-3.2-1B" - else f"tp{tp_size}_dp{dp_size}_cp{cp_size}", - ) - logger.info(f"ignore_sanity_checks is set to: {ignore_sanity_checks}") - logger.info("Wandb initialized.") - # Log the current file to wandb - wandb.save("test_train.py") - - else: - logger.info("Running in non-distributed mode. DeviceMesh not applicable.") - rank = 0 - world_size = 1 - local_rank = 0 - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - wandb.init( - project="tp_dp_test", - config={ - "tp_size": 1, - "dp_size": 1, - "global_batch_size": global_batch_size, - "model_name": model_name, - "dataset": "roneneldan/TinyStories-1M", - "seq_len": seq_len, - }, - name="llama_tp1_dp1_nondist" if model_name == "unsloth/Llama-3.2-1B" else "tp1_dp1_nondist", - ) - logger.info("Wandb initialized for non-distributed run.") - - # Load model and tokenizer - logger.info(f"Loading model and tokenizer from {model_name}") - tokenizer = AutoTokenizer.from_pretrained(model_name) - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token - logger.info(f"Set pad_token to eos_token: {tokenizer.pad_token}") - - model = AutoModelForCausalLM.from_pretrained( - model_name, - device_mesh=tp_mesh if dist.is_initialized() else None, - tp_plan="auto", - torch_dtype=torch.bfloat16, - ) - logger.info(f"Model loaded onto device mesh: {tp_mesh}") - - if dist.is_initialized(): - assert model.config.num_key_value_heads % tp_mesh.size() == 0, ( - f"num_key_value_heads={model.config.num_key_value_heads} must be divisible by tp_size={tp_mesh.size()}" - ) - device = torch.device(f"cuda:{local_rank}") - else: - model = model.to(device) - - logger.info(f"Using device: {device} for non-model tensors") - use_ddp = False - if dist.is_initialized() and dp_mesh.size() > 1: - # FSDP1 - model = FSDP(model, device_mesh=dp_mesh, sharding_strategy=ShardingStrategy.NO_SHARD) - # FSDP2 - # for transformer_block in model.model.layers: - # fully_shard(transformer_block, mesh=dp_mesh, reshard_after_forward=False) - # fully_shard(model.model, mesh=dp_mesh, reshard_after_forward=False) - # DDP - # replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100) - # assert len(list(model.parameters()))>5, "No parameters found in model. Probably DDP/FSDP bug.." # TODO: we should be cautious abt using model.parameters() - use_ddp = True - - model.train() - assert len(list(model.parameters())) > 0, "No parameters found in model. Probably DDP bug.." - assert len([p for p in model.parameters() if p.requires_grad]) > 0, ( - "No gradients found in model. Probably DDP bug.." - ) - - if dist.is_initialized() and not ignore_sanity_checks: - # assert model is replicated across all dp - for name, param in model.named_parameters(): - sanity_check_tensor_sync(param, dp_mesh) - - # assert model is different across tp (only for sharded params) - for name, param in model.named_parameters(): - if isinstance(param, DTensor) and param.placements[0].is_shard(): - # Only check sharded parameters for non-sync across TP - sanity_check_tensor_sync(param, tp_mesh, not_sync=True) - elif isinstance(param, DTensor) and param.placements[0].is_replicate(): - # Replicated parameters should be the same across TP - sanity_check_tensor_sync(param, tp_mesh) - - # assert model is replicated across cp - for name, param in model.named_parameters(): - sanity_check_tensor_sync(param, cp_mesh) - - # Load and preprocess TinyStories dataset - logger.info("Loading TinyStories dataset...") - raw_dataset = load_dataset("roneneldan/TinyStories", split="train[:1%]") # Use 1% for faster testing - - def tokenize_function(examples): - # Tokenize the text without padding - tokenized_batch = tokenizer( - examples["text"], padding=False, truncation=True, max_length=seq_len, return_tensors=None - ) - # Set labels to be the same as input_ids for Causal LM - tokenized_batch["labels"] = tokenized_batch["input_ids"].copy() - return tokenized_batch - - tokenized_dataset = raw_dataset.map(tokenize_function, batched=True, remove_columns=["text"]) - logger.info(f"Dataset loaded and tokenized. Size: {len(tokenized_dataset)}") - - # Create packed sequences - def create_packed_sequences(examples): - # Flatten all sequences - all_tokens = [] - for input_ids in examples["input_ids"]: - all_tokens.extend(input_ids) - - # Split into sequences of seq_len + 1 (for input + label) - num_sequences = len(all_tokens) // (seq_len + 1) - packed_input_ids = [] - packed_labels = [] - - for i in range(num_sequences): - start_idx = i * (seq_len + 1) - end_idx = start_idx + (seq_len + 1) - # Get the full sequence - full_sequence = all_tokens[start_idx:end_idx] - # For input_ids, remove the last token - packed_input_ids.append(full_sequence[:-1]) - # For labels, remove the first token - packed_labels.append(full_sequence[1:]) - - return {"input_ids": packed_input_ids, "labels": packed_labels} - - # Apply packing to the dataset - packed_dataset = tokenized_dataset.map( - create_packed_sequences, - batched=True, - remove_columns=tokenized_dataset.column_names, - batch_size=1000, # Process in batches for efficiency - num_proc=60, - ) - logger.info(f"Dataset packed. New size: {len(packed_dataset)}") - - # Shuffle the packed dataset - packed_dataset = packed_dataset.shuffle(seed=42) - logger.info("Packed dataset shuffled") - - # Calculate local batch size - if dist.is_initialized(): - assert global_batch_size % dp_mesh.size() == 0, ( - f"Global batch size ({global_batch_size}) must be divisible by DP size ({dp_mesh.size()})" - ) - local_batch_size = global_batch_size // dp_mesh.size() - else: - local_batch_size = global_batch_size - - logger.info( - f"Global batch size: {global_batch_size}, DP size: {dp_size if dist.is_initialized() else 1}, Local batch size: {local_batch_size}" - ) - - # Simple collate function since sequences are already packed - def collate_fn(batch): - input_ids = torch.tensor([item["input_ids"] for item in batch], dtype=torch.long) - labels = torch.tensor([item["labels"] for item in batch], dtype=torch.long) - return {"input_ids": input_ids, "labels": labels} - - if dist.is_initialized(): - sampler = DistributedSampler( - packed_dataset, num_replicas=dp_mesh.size(), rank=dp_mesh.get_local_rank(), shuffle=False - ) - else: - sampler = None - - dataloader = DataLoader( - packed_dataset, - batch_size=local_batch_size, - sampler=sampler, - shuffle=False, - collate_fn=collate_fn, - ) - logger.info(f"DataLoader created. Distributed: {dist.is_initialized()}") - - optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=0.1) - - # Training loop - logger.info(f"Starting training for {num_train_steps} steps...") - model.train() - step = 0 - while step < num_train_steps: - for batch in dataloader: - if step >= num_train_steps: - break # Exit loop if max steps reached - - # Move batch to appropriate device - batch = {k: v.to(device) for k, v in batch.items()} - - # Sanity checks for batch distribution (only if distributed) - if dist.is_initialized() and not ignore_sanity_checks: - # check batch is same across all tp - sanity_check_tensor_sync(batch["input_ids"], tp_mesh) - # check batch is different across dp - sanity_check_tensor_sync(batch["input_ids"], dp_mesh, not_sync=True) - - optimizer.zero_grad() - - # Add position_ids to batch before CP sharding - batch_size = batch["input_ids"].shape[0] - position_ids = torch.arange(0, seq_len, dtype=torch.long, device=device) - position_ids = position_ids.unsqueeze(0).expand(batch_size, -1) - batch["position_ids"] = position_ids - from torch.distributed.tensor.experimental._attention import _cp_options - - _cp_options.enable_load_balance = False - - with sdpa_kernel(sdpa_backend): # TODO: ideally move this to attention implementation - cp_context = ( - nullcontext() - if cp_mesh.size() == 1 - else context_parallel( - cp_mesh, - buffers=[ - batch["input_ids"], - batch["labels"], - batch["position_ids"], - ], # TODO: need to add attention mask - buffer_seq_dims=[1, 1, 1], - ) - ) - with cp_context: - # Pop labels from batch before model forward pass - labels = batch.pop("labels") - outputs = model(**batch) # [mbs, seq_len/cp] - loss = outputs.loss - logits = outputs.logits - - # Compute loss with shifted labels - loss = model.loss_function( - logits=logits, labels=None, shift_labels=labels, vocab_size=model.config.vocab_size - ) - - # Sanity checks for logits - if dist.is_initialized() and not ignore_sanity_checks: - # sanity_check_tensor_sync(logits, tp_mesh) # TODO: only true without sequence parallel - sanity_check_tensor_sync(logits, dp_mesh, not_sync=True) - sanity_check_tensor_sync(logits, cp_mesh, not_sync=True) - - loss.backward() - - # all reduce grads across dp_cp if applicable - all_reduce_grads(model, world_mesh, use_ddp=use_ddp) - - # Sanity checks for gradients (only if distributed) - if dist.is_initialized() and not ignore_sanity_checks: - # check grads are not same across all tp (for sharded grads) - for name, param in model.named_parameters(): - if param.grad is not None and isinstance(param.grad, DTensor): - if param.grad.placements[0].is_shard(): - sanity_check_tensor_sync(param.grad, tp_mesh, not_sync=True) - elif param.grad.placements[0].is_replicate(): - sanity_check_tensor_sync(param.grad, tp_mesh) - # check grads are same across dp - for name, param in model.named_parameters(): - if param.grad is not None and dp_mesh.size() > 1: - sanity_check_tensor_sync(param.grad, dp_mesh) - # check grads are same across cp - for name, param in model.named_parameters(): - if param.grad is not None and cp_mesh.size() > 1: - sanity_check_tensor_sync(param.grad, cp_mesh) - - # Calculate gradient norm and clip gradients - if hasattr(model, "clip_grad_norm_"): - # when using FSDP or DDP, model.parameters() doesn't work - gradnorm = model.clip_grad_norm_(max_norm=1.0, norm_type=2.0) - else: - assert len(list(model.parameters())) > 2, "No parameters found in model. Probably DDP bug.." - assert len([p for p in model.parameters() if p.requires_grad]) > 2, ( - "No gradients found in model. Probably DDP bug.." - ) - assert len([p for p in model.parameters() if p.grad is not None]) > 2, ( - "No gradients found in model. Probably DDP bug.." - ) - # only works with FSDP's NO_SHARD otherwise we should use FSDP's clip_grad_norm_ - gradnorm = clip_grad_norm_(model.parameters(), max_norm=1.0, norm_type=2.0, foreach=True) - - optimizer.step() - # Sanity checks for updated model parameters (only if distributed) - if dist.is_initialized() and not ignore_sanity_checks: - # check updated model is different across all tp (for sharded params) - for name, param in model.named_parameters(): - if isinstance(param, DTensor): - if param.placements[0].is_shard(): - sanity_check_tensor_sync(param, tp_mesh, not_sync=True) - elif param.placements[0].is_replicate(): - sanity_check_tensor_sync(param, tp_mesh) - # check updated model is same across dp - for name, param in model.named_parameters(): - sanity_check_tensor_sync(param, dp_mesh) - # check updated model is same across cp - for name, param in model.named_parameters(): - sanity_check_tensor_sync(param, cp_mesh) - - # allreduce loss across cp_dp before logging - if dist.is_initialized() and (cp_mesh.size() > 1 or dp_mesh.size() > 1): - dist.all_reduce(loss, group=world_mesh["dp_cp"].get_group(), op=dist.ReduceOp.AVG) - current_loss = loss.item() - - # Log loss and gradnorm to wandb (only on rank 0 of dp group) - if not dist.is_initialized() or dist.get_rank() == 0: - logger.info( - f"Step: {step} | GBS: {global_batch_size} | DP: {dp_mesh.size()} | TP: {tp_mesh.size()} | CP: {cp_mesh.size()} | Loss: {current_loss} | Gradnorm: {gradnorm} | lr: {LR}" - ) - wandb.log( - { - "train/loss": current_loss, - "train/gradnorm": gradnorm, - "step": step, - "lr": LR, - "GBS": global_batch_size, - } - ) - - step += 1 # Increment step count - - logger.info("Training loop finished.") - - # Save model using DCP (only if distributed) - if dist.is_initialized(): - state_dict = {"app": AppState(model, optimizer)} - dcp.save( - state_dict=state_dict, - checkpoint_id=CHECKPOINT_DIR, - ) - logger.info(f"Saved checkpoint to {CHECKPOINT_DIR}") - else: - # Fallback to regular save for non-distributed case - save_dir = "test_model_nondist" - model.save_pretrained(save_dir, safe_serialization=False) - tokenizer.save_pretrained(save_dir) # Save tokenizer too - logger.info(f"Saved model to {save_dir}") - - # Example of loading the checkpoint (only if distributed) - if dist.is_initialized(): - # Create a new model instance - logger.info("Creating new model instance for verification") - new_model = AutoModelForCausalLM.from_pretrained( - model_name, - device_mesh=tp_mesh, - torch_dtype=torch.bfloat16, # Use same dtype - ) - new_optimizer = optim.AdamW(new_model.parameters(), lr=LR) - - # Load checkpoint into new model - state_dict = {"app": AppState(new_model, new_optimizer)} - dcp.load( - state_dict=state_dict, - checkpoint_id=CHECKPOINT_DIR, - ) - logger.info("Loaded checkpoint into new model") - - # Verify model weights match - logger.info("Verifying model weights match...") - for (name1, param1), (name2, param2) in zip(model.named_parameters(), new_model.named_parameters()): - torch.testing.assert_close( - param1.to_local(), - param2.to_local(), - rtol=1e-3, - atol=1e-3, - msg=f"Weights mismatch in {name1} vs {name2}", - ) - - # Verify optimizer states match - logger.info("Verifying optimizer states match...") - for name1, state1 in optimizer.state_dict().items(): - state2 = new_optimizer.state_dict()[name1] - if name1 == "state": - # Compare state dictionaries for each parameter - for param_id, param_state1 in state1.items(): - param_state2 = state2[param_id] - # Compare each state component (step, exp_avg, exp_avg_sq) - for key, value1 in param_state1.items(): - value2 = param_state2[key] - if isinstance(value1, DTensor): - # Convert DTensors to local tensors for comparison - torch.testing.assert_close( - value1.to_local(), - value2.to_local(), - rtol=1e-5, - atol=1e-5, - msg=f"Optimizer state mismatch in state[{param_id}][{key}]", - ) - else: - torch.testing.assert_close( - value1, - value2, - rtol=1e-5, - atol=1e-5, - msg=f"Optimizer state mismatch in state[{param_id}][{key}]", - ) - elif name1 == "param_groups": - # Compare param_groups (excluding the actual params list) - for i, (group1, group2) in enumerate(zip(state1, state2)): - for key in group1: - if key != "params": # Skip comparing the params list - assert group1[key] == group2[key], f"Param group mismatch in param_groups[{i}][{key}]" - - # Run a forward pass with both models to verify outputs match - logger.info("Running forward pass verification...") - with torch.no_grad(): - # Use the last batch for verification - batch = {k: v.to(device) for k, v in batch.items()} # Ensure batch is on correct device - original_outputs = model(**batch) - new_outputs = new_model(**batch) - torch.testing.assert_close( - original_outputs.logits.to_local(), - new_outputs.logits.to_local(), - rtol=1e-3, - atol=1e-3, - msg="Model outputs do not match!", - ) # Increased tolerance slightly for bf16 - - # Clean up distributed environment and finish wandb run - if dist.is_initialized(): - dist.destroy_process_group() - logger.info("Cleaned up distributed process group") - # Finish wandb run on rank 0 - if dist.get_rank() == 0: - wandb.finish() - logger.info("Wandb run finished.") - else: - wandb.finish() - logger.info("Wandb run finished.") - - -def all_reduce_grads(model, world_mesh, use_ddp): - """All reduce gradients across dp_cp if applicable.""" - cp_mesh = world_mesh["cp"] - if use_ddp: - # DDP takes care of syncing grads - mesh = cp_mesh - else: - mesh = world_mesh["dp", "cp"]._flatten(mesh_dim_name="dp_cp") - if dist.is_initialized() and mesh.size() > 1: - for name, param in model.named_parameters(): - if param.grad is not None: - # Workaround for cross-mesh communication limitation with DTensor gradients - if isinstance(param.grad, DTensor): - local_grad = param.grad.to_local() - # Ensure grad requires grad for inplace modification checks (might not be needed) - # local_grad = local_grad.detach().requires_grad_(True) - torch.distributed.all_reduce(local_grad, op=torch.distributed.ReduceOp.SUM, group=mesh.get_group()) - local_grad = local_grad / mesh.size() - # Assign averaged grad back - need careful handling if DTensor structure is complex - # This simple assignment might work if the grad structure matches param structure - param.grad = DTensor.from_local( - local_grad, device_mesh=param.grad.device_mesh, placements=param.grad.placements - ) - else: - # Handle regular tensors if any exist (e.g. buffers not converted to DTensor) - torch.distributed.all_reduce(param.grad, op=torch.distributed.ReduceOp.AVG, group=mesh.get_group()) - - -class ContextParallelCollator: - """Collator for context parallel training that splits sequences into chunks.""" - - def __init__(self, cp_mesh: Optional[DeviceMesh] = None): - self.cp_mesh = cp_mesh - - def __call__(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: - batch = default_collate(batch) - if self.cp_mesh is not None and self.cp_mesh.size() > 1: - # Get sequence length from the input batch - seq_len = batch["input_ids"].shape[1] - assert seq_len % self.cp_mesh.size() == 0, ( - f"Sequence length {seq_len} must be divisible by CP size {self.cp_mesh.size()}" - ) - chunk_size = seq_len // self.cp_mesh.size() - cp_rank = self.cp_mesh.get_local_rank() - start_idx = cp_rank * chunk_size - end_idx = start_idx + chunk_size - - # Keep only the local chunk of the sequence - batch["input_ids"] = batch["input_ids"][:, start_idx:end_idx] - batch["attention_mask"] = batch["attention_mask"][:, start_idx:end_idx] - batch["labels"] = batch["labels"][:, start_idx:end_idx] - - return batch - - -class AppState(Stateful): - """Wrapper for checkpointing the Application State including model and optimizer.""" - - def __init__(self, model, optimizer=None): - self.model = model - self.optimizer = optimizer - - def state_dict(self): - model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer) - return {"model": model_state_dict, "optim": optimizer_state_dict} - - def load_state_dict(self, state_dict): - set_state_dict( - self.model, self.optimizer, model_state_dict=state_dict["model"], optim_state_dict=state_dict["optim"] - ) - - -def sanity_check_tensor_sync( - tensor: torch.Tensor, mesh: DeviceMesh, rtol: float = 1e-4, atol: float = 1e-4, not_sync: bool = False -) -> None: - """ - Verify that a tensor is synchronized (or not synchronized) across all processes in the mesh's process group. - Handles both regular tensors and DTensors. - - Args: - tensor (torch.Tensor): The tensor to check for synchronization (can be DTensor) - mesh (DeviceMesh): The device mesh containing the process group - rtol (float): Relative tolerance for comparison - atol (float): Absolute tolerance for comparison - not_sync (bool): If True, asserts that tensors are NOT synchronized. If False, asserts they are synchronized. - """ - if not dist.is_initialized() or mesh.size() == 1: - return # No need to check in non-distributed mode - - # Get the process group from the mesh - pg = mesh.get_group() - - # Convert DTensor to local tensor if needed - if hasattr(tensor, "to_local"): - local_tensor = tensor.to_local() - else: - local_tensor = tensor - - # Gather tensors from all processes - world_size = dist.get_world_size(pg) - gathered_tensors = [torch.empty_like(local_tensor) for _ in range(world_size)] - dist.all_gather(gathered_tensors, local_tensor, group=pg) - - # Compare each tensor with the first one - for i in range(1, world_size): - try: - torch.testing.assert_close(gathered_tensors[0], gathered_tensors[i], rtol=rtol, atol=atol) - except AssertionError as e: - if not_sync: - continue - # # Add detailed debugging for logit synchronization issues - # print(f"\nLogit synchronization error between rank 0 and rank {i}:") - # print(f"Tensor shape: {gathered_tensors[0].shape}") - # print(f"Number of mismatched elements: {(gathered_tensors[0] != gathered_tensors[i]).sum()}") - # print(f"Percentage of mismatched elements: {((gathered_tensors[0] != gathered_tensors[i]).sum() / gathered_tensors[0].numel() * 100):.2f}%") - - # # Find the first few mismatches - # mismatches = torch.nonzero(gathered_tensors[0] != gathered_tensors[i]) - # print("\nFirst few mismatches:") - # for idx in mismatches[:5]: - # idx = tuple(idx.tolist()) - # print(f"Index {idx}:") - # print(f"Rank 0 value: {gathered_tensors[0][idx]}") - # print(f"Rank {i} value: {gathered_tensors[i][idx]}") - # print(f"Absolute difference: {abs(gathered_tensors[0][idx] - gathered_tensors[i][idx])}") - # print(f"Relative difference: {abs(gathered_tensors[0][idx] - gathered_tensors[i][idx]) / max(abs(gathered_tensors[0][idx]), abs(gathered_tensors[i][idx]))}") - - # # Check if differences are systematic (e.g., all positive or negative) - # diff = gathered_tensors[0] - gathered_tensors[i] - # print(f"\nDifference statistics:") - # print(f"Mean difference: {diff.mean()}") - # print(f"Std difference: {diff.std()}") - # print(f"Max positive difference: {diff.max()}") - # print(f"Max negative difference: {diff.min()}") - raise e - - -def clip_grad_norm_( - parameters: Iterable[torch.Tensor], - max_norm: float, - norm_type: float = 2.0, - error_if_nonfinite: bool = False, - foreach: bool | None = None, -) -> torch.Tensor: - """ - Clip the gradient norm of an iterable of parameters. - """ - # Filter out parameters with no gradients - parameters = [p for p in parameters if p.grad is not None] - assert len(parameters) > 0, "No parameters with gradients found" - - # Calculate total norm - if norm_type == float("inf"): - total_norm = max(p.grad.detach().abs().max() for p in parameters) - else: - total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type) for p in parameters]), norm_type) - - # Convert DTensor to local tensor if needed - if isinstance(total_norm, DTensor): - total_norm = total_norm.full_tensor() - - # Clip gradients - clip_coef = max_norm / (total_norm + 1e-6) - if clip_coef < 1: - for p in parameters: - p.grad.detach().mul_(clip_coef) - - return total_norm - - -def check_params_sync(model_params, original_params): - """ - Check if original_params are being updated in sync with model parameters. - - Args: - model_params: Iterator of model parameters after update - original_params: List of original parameters before DDP wrapping - """ - for mp, op in zip(model_params, original_params): - if isinstance(mp, DTensor): - mp = mp.to_local() - if isinstance(op, DTensor): - op = op.to_local() - if not torch.allclose(mp.data, op.data, rtol=0, atol=0): - raise RuntimeError(f"Parameters out of sync: model param {mp.data} != original param {op.data}") - return True - - -def get_parameters(model: nn.Module) -> Iterable[torch.Tensor]: - """ - Get all parameters from a model by iterating over its modules. - This is an alternative to model.parameters() that works with DTensor models. - - Args: - model (nn.Module): The model to get parameters from - - Returns: - Iterable[torch.Tensor]: An iterator over all parameters in the model - """ - for name, module in model._modules.items(): - # Look for parameters in module attributes - for attr_name, attr in module.__dict__.items(): - if isinstance(attr, torch.Tensor) and attr.requires_grad: - yield attr - # Recursively get parameters from submodules - for param in get_parameters(module): - yield param - - -def update_model_parameters(model: nn.Module) -> None: - """ - Update model._parameters using named_modules() to ensure all parameters are properly tracked. - - Args: - model (nn.Module): The model to update parameters for - """ - # Clear existing parameters - model._parameters = {} - - # Add parameters from named_modules - for name, module in model.named_modules(): - # Skip the root module itself - if name == "": - continue - - # Get the parameter name by removing 'module.' prefix if it exists - param_name = name.replace("module.", "") - - # Add weight and bias parameters if they exist - if hasattr(module, "weight") and module.weight is not None: - model._parameters[f"{param_name}.weight"] = module.weight - if hasattr(module, "bias") and module.bias is not None: - model._parameters[f"{param_name}.bias"] = module.bias - - -if __name__ == "__main__": - main() diff --git a/examples/pytorch/context_parallel.py b/examples/pytorch/context_parallel.py deleted file mode 100644 index 22cc75b20f..0000000000 --- a/examples/pytorch/context_parallel.py +++ /dev/null @@ -1,94 +0,0 @@ -# Copyright 2024 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import os - -import torch -import torch.distributed as dist -from torch.distributed.device_mesh import init_device_mesh -from torch.distributed.tensor.experimental import context_parallel -from torch.nn.attention import SDPBackend, sdpa_kernel -from torch.nn.parallel import DistributedDataParallel as DDP - -from transformers import AutoModelForCausalLM -from transformers.loss.loss_utils import ForCausalLMLoss - - -world_size = int(os.environ.get("WORLD_SIZE", "1")) -cp_mesh = init_device_mesh("cuda", (world_size,)) -rank = torch.distributed.get_node_local_rank() - -device = "cuda" -dtype = torch.bfloat16 -sdpa_backend = SDPBackend.FLASH_ATTENTION - -# prepare inputs -batch_size = 1 -seq_len = 128 - -input_ids = torch.randint(low=8, high=64, size=(batch_size, seq_len), device=device) - -ignore_index = -100 -# When using CP, we need to use `shift_labels` -shift_labels = torch.nn.functional.pad(input_ids, (0, 1), value=ignore_index) -shift_labels = shift_labels[..., 1:].contiguous() - -position_ids = ( - torch.cumsum(torch.ones(size=input_ids.size(), dtype=input_ids.dtype, device=input_ids.device), dim=1) - 1 -) - -# sync input as they are created randomly -dist.broadcast(input_ids, src=0) -dist.broadcast(shift_labels, src=0) -dist.broadcast(position_ids, src=0) - -# model and optimizer -repo_id = "Qwen/Qwen2.5-Coder-0.5B-Instruct" -model = AutoModelForCausalLM.from_pretrained(repo_id, torch_dtype=dtype, device_map=device) -optimizer = torch.optim.Adam(model.parameters(), lr=1e-5) - -model.train() -model.zero_grad() -optimizer.zero_grad() - -# For loss -vocab_size = model.config.vocab_size - -# so training could be synced -model = DDP(model, device_ids=[rank]) - -# prepare for CP -buffers = (input_ids, shift_labels, position_ids) -buffer_seq_dims = (1, 1, 1) -# `no_restore_buffers=set(buffers)` is required if `loss.backward` is outside `context_parallel`. -# no_restore_buffers = set(buffers) -no_restore_buffers = None - -# run with CP -with sdpa_kernel(sdpa_backend): - with context_parallel( - cp_mesh, - buffers=buffers, - buffer_seq_dims=buffer_seq_dims, - no_restore_buffers=no_restore_buffers, - ): - outputs = model(input_ids, shift_labels=shift_labels, position_ids=position_ids) - print(outputs.logits.shape) - - # So far we need to compute `loss` outside `model.forward` when using `shift_labels` - # loss = outputs.loss - loss = ForCausalLMLoss(logits=outputs.logits, labels=None, shift_labels=shift_labels, vocab_size=vocab_size) - - # This could be outside `context_parallel` context if `no_restore_buffers` is specified - loss.backward() - optimizer.step() diff --git a/setup.py b/setup.py index 52024f77c1..9fe5007305 100644 --- a/setup.py +++ b/setup.py @@ -125,7 +125,7 @@ _deps = [ "jaxlib>=0.4.1,<=0.4.13", "jieba", "jinja2>=3.1.0", - "kenlm", + "kenlm@git+https://github.com/ydshieh/kenlm@78f664fb3dafe1468d868d71faf19534530698d5", # Keras pin - this is to make sure Keras 3 doesn't destroy us. Remove or change when we have proper support. "keras>2.9,<2.16", "keras-nlp>=0.3.1,<0.14.0", # keras-nlp 0.14 doesn't support keras 2, see pin on keras. @@ -315,7 +315,7 @@ extras["audio"] = deps_list( "librosa", "pyctcdecode", "phonemizer", - "kenlm", + "kenlm@git+https://github.com/ydshieh/kenlm@78f664fb3dafe1468d868d71faf19534530698d5", ) # `pip install ".[speech]"` is deprecated and `pip install ".[torch-speech]"` should be used instead extras["speech"] = deps_list("torchaudio") + extras["audio"] diff --git a/src/transformers/dependency_versions_table.py b/src/transformers/dependency_versions_table.py index c01f5bb388..dc2b37a192 100644 --- a/src/transformers/dependency_versions_table.py +++ b/src/transformers/dependency_versions_table.py @@ -32,7 +32,7 @@ deps = { "jaxlib": "jaxlib>=0.4.1,<=0.4.13", "jieba": "jieba", "jinja2": "jinja2>=3.1.0", - "kenlm": "kenlm", + "kenlm@git+https://github.com/ydshieh/kenlm@78f664fb3dafe1468d868d71faf19534530698d5": "kenlm@git+https://github.com/ydshieh/kenlm@78f664fb3dafe1468d868d71faf19534530698d5", "keras": "keras>2.9,<2.16", "keras-nlp": "keras-nlp>=0.3.1,<0.14.0", "kernels": "kernels>=0.4.4,<0.5", diff --git a/src/transformers/integrations/__init__.py b/src/transformers/integrations/__init__.py index 1b87a554d3..8d03c5cf79 100755 --- a/src/transformers/integrations/__init__.py +++ b/src/transformers/integrations/__init__.py @@ -142,7 +142,7 @@ except OptionalDependencyNotAvailable: else: _import_structure["tensor_parallel"] = [ "shard_and_distribute_module", - "ALL_PARALLEL_STYLES", + "SUPPORTED_TP_STYLES", "translate_to_torch_parallel_style", ] try: @@ -271,7 +271,7 @@ if TYPE_CHECKING: pass else: from .tensor_parallel import ( - ALL_PARALLEL_STYLES, + SUPPORTED_TP_STYLES, shard_and_distribute_module, translate_to_torch_parallel_style, ) diff --git a/src/transformers/integrations/tensor_parallel.py b/src/transformers/integrations/tensor_parallel.py index d07a768b01..e788321b49 100644 --- a/src/transformers/integrations/tensor_parallel.py +++ b/src/transformers/integrations/tensor_parallel.py @@ -13,15 +13,11 @@ # limitations under the License. from __future__ import annotations -import operator -import os import re -from collections.abc import MutableMapping -from functools import partial, reduce -from typing import Callable, List, Optional, Tuple, Union +from functools import lru_cache, partial +from typing import List, Optional, Tuple, Union import torch -import torch.distributed as dist from torch import nn from ..utils import is_torch_greater_or_equal, logging @@ -39,56 +35,6 @@ if is_torch_greater_or_equal("2.5") and _torch_distributed_available: from torch.distributed.tensor import DTensor, Placement, Replicate, Shard -def initialize_tensor_parallelism(tp_plan, tp_size=None): - r""" - Sets up the device mesh and initilized the backend for tensor parallelism. - This function is called when the model is loaded and the TP plan is set to 'auto'. - """ - if tp_plan is None: - return None, None, None - - if not is_torch_greater_or_equal("2.5"): - raise EnvironmentError("Tensor parallel is only supported for `torch>=2.5`.") - - # Detect the accelerator on the machine. If no accelerator is available, it returns CPU. - device_type = torch._C._get_accelerator().type - if not torch.distributed.is_initialized(): - try: - rank = int(os.environ["RANK"]) - local_rank = int(os.environ["LOCAL_RANK"]) - world_size = int(os.environ["WORLD_SIZE"]) - - backend_map = {"cuda": "nccl", "cpu": "gloo", "xpu": "ccl", "hpu": "hccl"} - backend = backend_map.get(device_type) - if device_type == "cpu" and int(os.environ.get("CCL_WORKER_COUNT", 0)): - backend = "ccl" - - torch.distributed.init_process_group(backend=backend, rank=rank, world_size=world_size) - current_device = getattr(torch, device_type) - if device_type != "cpu": - current_device.set_device(local_rank) - - except Exception as e: - raise EnvironmentError( - "We tried to initialize torch.distributed for you, but it failed. Make " - "sure you init torch distributed in your script to use `tp_plan='auto'`." - ) from e - index = current_device.current_device() if device_type != "cpu" else None - tp_device = torch.device(device_type, index) - - # Silence output for non-primary ranks - if index is not None and index > 0: - import sys - - sys.stdout = open(os.devnull, "w") - sys.stderr = open(os.devnull, "w") - - device_map = tp_device - tp_size = tp_size if tp_size is not None else torch.distributed.get_world_size() - device_mesh = torch.distributed.init_device_mesh(tp_device.type, (tp_size,)) - return tp_device, device_map, device_mesh - - def _blocks_to_block_sizes(total_size: int, blocks: Union[int, List[int]]) -> List[int]: """ Convert block count or proportions to block sizes. @@ -274,38 +220,18 @@ def repack_weights( def get_tensor_shard(param, empty_param, device_mesh, rank, dim): - """ - Generalized tensor sharding across a multi-dimensional device mesh. - - Args: - param (torch.Tensor): The tensor to shard. - empty_param (torch.Tensor): A tensor used for shape reference. - device_mesh (torch.Tensor): Shape [d_0, ..., d_n] representing the mesh. - rank (int): Global rank of the current process/device. - dim (int): Dimension along which to shard the tensor. - """ - param_dim = empty_param.dim() - if dim < 0: - dim = param_dim + dim - if dim >= param_dim: - raise ValueError(f"dim {dim} is out of bounds for tensor of dimension {param_dim}") - - # Flatten the mesh to get the total number of devices - mesh_shape = device_mesh.shape - world_size = reduce(operator.mul, mesh_shape) - - if rank >= world_size: - raise ValueError(f"Rank {rank} is out of bounds for mesh size {world_size}") - - shard_size = empty_param.shape[dim] // world_size - start = rank * shard_size - end = start + shard_size - - # Construct slicing index dynamically - slice_indices = [slice(None)] * param_dim - slice_indices[dim] = slice(start, end) - - return param[tuple(slice_indices)] + if dim == 0: + size_ = empty_param.shape[0] + param = param[rank * (size_ // device_mesh.size()) : (rank + 1) * (size_ // device_mesh.size()), ...] + elif dim == 1 or dim == -2: + size_ = empty_param.shape[-2] + param = param[..., rank * (size_ // device_mesh.size()) : (rank + 1) * (size_ // device_mesh.size()), :] + elif dim == 2 or dim == -1: + size_ = empty_param.shape[-1] + param = param[..., rank * (size_ // device_mesh.size()) : (rank + 1) * (size_ // device_mesh.size())] + else: + raise ValueError(f"Unsupported dim {dim}, only dim 0, 1 or 2 are supported") + return param def distribute_module( @@ -413,41 +339,6 @@ class IsolatedParallel(TensorParallelLayer): ) -class ReplicateParallel(TensorParallelLayer): - """ - This class is used to replicate computation in a TP layer (used in SP regions when we don't use sequence parallelism for example) - """ - - def __init__(self, *, use_dtensor=True, use_local_output=True): - super().__init__() - self.input_layouts = (Replicate(),) - self.output_layouts = (Replicate(),) - self.desired_input_layouts = (Replicate(),) - self.use_local_output = use_local_output - self.use_dtensor = use_dtensor - - @staticmethod - def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh): - # TODO: figure out dynamo support for instance method and switch this to instance method - # annotate module input placements/sharding with input_layouts - input_tensor = inputs[0] - if not isinstance(input_tensor, DTensor): - input_tensor = DTensor.from_local(input_tensor, device_mesh, input_layouts, run_check=False) - - return input_tensor - - @staticmethod - def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh): - return outputs.to_local() if use_local_output else outputs - - def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh): - param = param[...].to(param_casting_dtype) - if to_contiguous: - param = param.contiguous() - param = DTensor.from_local(param, device_mesh, [Replicate()], run_check=False) - return param - - class ColwiseParallel(TensorParallelLayer): """ General tensor parallel layer for transformers. @@ -720,67 +611,52 @@ class SequenceParallel(TensorParallelLayer): return nn.Parameter(parameter, requires_grad=parameter.is_floating_point()) -class ParallelInterface(MutableMapping): +SUPPORTED_TP_STYLES = { + "colwise", + "rowwise", + "colwise_rep", + "rowwise_rep", + "local_colwise", + "local_rowwise", + "local", + "gather", + "local_packed_rowwise", + "sequence_parallel", +} + + +@lru_cache +def translate_to_torch_parallel_style(style: str): """ - Dict-like object keeping track of allowed attention functions. You can easily add a new attention function - with a call to `register()`. If a model needs to locally overwrite an existing attention function, say `sdpa`, - it needs to declare a new instance of this class inside the `modeling_.py`, and declare it on that instance. + In model configurations, we use a neutral type (string) to specify parallel + styles, here we translate them into torch.distributed tensor-parallel + types. """ + if not isinstance(style, str): + raise ValueError(f"Unsupported parallel style type {type(style)}, expected str") - # Class instance object, so that a call to `register` can be reflected into all other files correctly, even if - # a new instance is created (in order to locally override a given function) - - def __init__(self): - self._local_mapping = {} - - ParallelInterface._global_mapping = { - "colwise": ColwiseParallel(), - "rowwise": RowwiseParallel(), - "colwise_rep": ColwiseParallel(output_layouts=Replicate()), - "rowwise_rep": RowwiseParallel(input_layouts=Replicate()), - "local_colwise": ColwiseParallel(use_dtensor=False), - "local_rowwise": RowwiseParallel(use_dtensor=False), - "local": IsolatedParallel(), - "gather": GatherParallel(), - "local_packed_rowwise": PackedRowwiseParallel(use_dtensor=False), - "sequence_parallel": SequenceParallel(), - "replicate": ReplicateParallel(), - } - - def __getitem__(self, key): - # First check if instance has a local override - if key in self._local_mapping: - return self._local_mapping[key] - return self._global_mapping[key] - - def __setitem__(self, key, value): - # Allow local update of the default functions without impacting other instances - self._local_mapping.update({key: value}) - - def __delitem__(self, key): - del self._local_mapping[key] - - def __iter__(self): - # Ensure we use all keys, with the overwritten ones on top - return iter({**self._global_mapping, **self._local_mapping}) - - def __len__(self): - return len(self._global_mapping.keys() | self._local_mapping.keys()) - - @classmethod - def register(cls, key: str, value: Callable): - cls._global_mapping.update({key: value}) - - def valid_keys(self) -> List[str]: - return list(self.keys()) - - -# Global AttentionInterface shared by all models which do not need to overwrite any of the existing ones - -if is_torch_greater_or_equal("2.5") and _torch_distributed_available: - ALL_PARALLEL_STYLES: ParallelInterface = ParallelInterface() -else: - ALL_PARALLEL_STYLES = None + if style == "colwise": + return ColwiseParallel() + elif style == "rowwise": + return RowwiseParallel() + elif style == "colwise_rep": + return ColwiseParallel(output_layouts=Replicate()) + elif style == "rowwise_rep": + return RowwiseParallel(input_layouts=Replicate()) + elif style == "local_colwise": + return ColwiseParallel(use_dtensor=False) + elif style == "local_rowwise": + return RowwiseParallel(use_dtensor=False) + elif style == "local": + return IsolatedParallel() + elif style == "gather": + return GatherParallel() + elif style == "local_packed_rowwise": + return PackedRowwiseParallel(use_dtensor=False) + elif style == "sequence_parallel": + return SequenceParallel() + else: + raise ValueError(f"Unsupported parallel style value: {style}") def convert_local_tensor_to_dtensor( @@ -846,15 +722,13 @@ def add_tensor_parallel_hooks_to_module(model, module, tp_plan, layer_name, curr # 1. We add hooks to the layer being loaded: if current_module_plan is not None: - tp_layer = ALL_PARALLEL_STYLES[current_module_plan] + tp_layer = translate_to_torch_parallel_style(current_module_plan) try: tp_layer.prepare_module_tp(module, device_mesh) except NotImplementedError as e: print( f"Trying to prepare {layer_name}, but it's not supported. Corresponding module: {module} Fix it's TP plan: {e}" ) - module._hf_tp_plan = current_module_plan - module.__repr__ = lambda: f"{module.__repr__()}\nTP Plan: {current_module_plan}" # 2. We add hooks to the parent module if needed if "." in layer_name: @@ -862,11 +736,9 @@ def add_tensor_parallel_hooks_to_module(model, module, tp_plan, layer_name, curr generic_name = re.sub(r"\d+", "*", parent_layer_name) # The module itself needs hooks if module_plan := tp_plan.get(generic_name, False): - tp_layer = ALL_PARALLEL_STYLES[module_plan] + tp_layer = translate_to_torch_parallel_style(module_plan) module_to_tp_ = model.get_submodule(parent_layer_name) tp_layer.prepare_module_tp(module_to_tp_, device_mesh) - module_to_tp_._hf_tp_plan = current_module_plan - module_to_tp_.__repr__ = lambda: f"{module_to_tp_.__repr__()}\nTP Plan: {current_module_plan}" def shard_and_distribute_module( @@ -888,29 +760,28 @@ def shard_and_distribute_module( current_module_plan = _get_parameter_tp_plan(parameter_name, tp_plan) - if current_module_plan is None: - current_module_plan = "replicate" - if dist.get_rank() == 0: - logger.info(f"Tensor parallel plan for {param_name} not found, using default 'replicate' plan.") - else: - if dist.get_rank() == 0: - logger.info(f"Tensor parallel plan for {param_name}: {current_module_plan}") - # Add hooks to the module if not done yet # add_tensor_parallel_hooks_to_module(model, module_to_tp, tp_plan, param_name, current_module_plan, device_mesh) if not getattr(module_to_tp, "_is_hooked", False): add_tensor_parallel_hooks_to_module(model, module_to_tp, tp_plan, param_name, current_module_plan, device_mesh) module_to_tp._is_hooked = True - try: - tp_layer = ALL_PARALLEL_STYLES[current_module_plan] - param = tp_layer.partition_tensor( - param, empty_param, param_type, param_casting_dtype, is_contiguous, rank, device_mesh - ) - except NotImplementedError as e: - print( - f"Trying to prepare {parameter_name}, but it's not supported. Corresponding module: {module_to_tp} Fix it's TP plan, current layer: {tp_layer} : {e}" - ) + if current_module_plan is not None: + try: + tp_layer = translate_to_torch_parallel_style(current_module_plan) + param = tp_layer.partition_tensor( + param, empty_param, param_type, param_casting_dtype, is_contiguous, rank, device_mesh + ) + except NotImplementedError as e: + print( + f"Trying to prepare {parameter_name}, but it's not supported. Corresponding module: {module_to_tp} Fix it's TP plan, current layer: {tp_layer} : {e}" + ) + else: + # TODO log no plan modules in set + # print("No plan for", parameter_name,end ="\n") + param = param[...].to(param_casting_dtype) + if is_contiguous: + param = param.contiguous() # SUPER IMPORTANT we have to use setattr # otherwise loading is crazy slow diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index ece787c7b6..b40b7cd2b3 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -62,9 +62,8 @@ from .integrations.flash_attention import flash_attention_forward from .integrations.flex_attention import flex_attention_forward from .integrations.sdpa_attention import sdpa_attention_forward from .integrations.tensor_parallel import ( - ALL_PARALLEL_STYLES, + SUPPORTED_TP_STYLES, _get_parameter_tp_plan, - initialize_tensor_parallelism, repack_weights, replace_state_dict_local_with_dtensor, shard_and_distribute_module, @@ -798,7 +797,7 @@ def _load_state_dict_into_meta_model( param_name, casting_dtype, to_contiguous, - device_mesh.get_local_rank(), + int(os.environ["RANK"]), # the rank device_mesh, ) else: @@ -1965,9 +1964,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi if self._tp_plan is not None and is_torch_greater_or_equal("2.3"): for _, v in self._tp_plan.items(): - if v not in ALL_PARALLEL_STYLES: + if v not in SUPPORTED_TP_STYLES: raise ValueError( - f"Unsupported tensor parallel style {v}. Supported styles are {ALL_PARALLEL_STYLES}" + f"Unsupported tensor parallel style {v}. Supported styles are {SUPPORTED_TP_STYLES}" ) def dequantize(self): @@ -3560,7 +3559,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi state_dict = replace_state_dict_local_with_dtensor(state_dict, self._tp_plan, self._device_mesh) if safe_serialization: - # TODO: fix safe_serialization for tied weights # Safetensors does not allow tensor aliasing. # We're going to remove aliases before saving ptrs = collections.defaultdict(list) @@ -4042,8 +4040,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi `torchrun [args] script.py`. This will be much faster than using a `device_map`, but has limitations. tp_size (`str`, *optional*): A torch tensor parallel degree. If not provided would default to world size. - device_mesh (`torch.distributed.DeviceMesh`, *optional*): - A torch device mesh. If not provided would default to world size. Used only for tensor parallel for now. offload_folder (`str` or `os.PathLike`, *optional*): If the `device_map` contains any value `"disk"`, the folder where we will offload weights. offload_state_dict (`bool`, *optional*): @@ -4141,7 +4137,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi gguf_file = kwargs.pop("gguf_file", None) tp_plan = kwargs.pop("tp_plan", None) tp_size = kwargs.pop("tp_size", None) - device_mesh = kwargs.pop("device_mesh", None) trust_remote_code = kwargs.pop("trust_remote_code", None) # Load models with hardcoded key mapping on class for VLMs only, to keep BC and standardize model @@ -4177,13 +4172,59 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi # We need to correctly dispatch the model on the current process device. The easiest way for this is to use a simple # `device_map` pointing to the correct device - if device_mesh is None: - tp_plan, device_map, device_mesh = initialize_tensor_parallelism(tp_plan, tp_size=None) - else: - # TODO: make device_mesh support multiple dimensions - if device_mesh.ndim == 1: - raise ValueError("device_mesh must be 1 dimensional and will be used for TP") - device_map = torch.device(device_mesh.device_type, int(os.environ["LOCAL_RANK"])) + device_mesh = None + if tp_plan is not None: + if not is_torch_greater_or_equal("2.5"): + raise EnvironmentError("tensor parallel is only supported for `torch>=2.5`.") + + # Detect the accelerator on the machine. If no accelerator is available, it returns CPU. + device_type = torch._C._get_accelerator().type + + if not torch.distributed.is_initialized(): + try: + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + if device_type == "cuda": + torch.distributed.init_process_group( + "nccl", rank=rank, world_size=world_size, init_method="env://" + ) + torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) + elif device_type == "cpu": + cpu_backend = "ccl" if int(os.environ.get("CCL_WORKER_COUNT", 0)) else "gloo" + torch.distributed.init_process_group(cpu_backend, rank=rank, world_size=world_size) + elif device_type == "xpu": + torch.distributed.init_process_group("ccl", rank=rank, world_size=world_size) + torch.xpu.set_device(int(os.environ["LOCAL_RANK"])) + elif device_type == "hpu": + torch.distributed.init_process_group("hccl", rank=rank, world_size=world_size) + torch.hpu.set_device(int(os.environ["LOCAL_RANK"])) + + except Exception as e: + raise EnvironmentError( + "We tried to initialize torch.distributed for you, but it failed, make" + "sure you init torch distributed in your script to use `tp_plan='auto'`" + ) from e + + # Get device with index assuming equal number of devices per host + if device_type == "xpu": + index = torch.xpu.current_device() + elif device_type == "hpu": + index = torch.hpu.current_device() + else: + index = None if device_type == "cpu" else torch.cuda.current_device() + tp_device = torch.device(device_type, index) + + if index is not None and index > 0: + import sys + + sys.stdout = open(os.devnull, "w") + sys.stderr = open(os.devnull, "w") + # This is the easiest way to dispatch to the current process device + device_map = tp_device + + # Assuming sharding the model onto the world when tp_size not provided + tp_size = tp_size if tp_size is not None else torch.distributed.get_world_size() + device_mesh = torch.distributed.init_device_mesh(tp_device.type, (tp_size,)) if use_auth_token is not None: warnings.warn( @@ -5101,7 +5142,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi name, casting_dtype, to_contiguous, - device_mesh.get_local_rank(), + os.environ["RANK"], device_mesh, )