From 20ce210ab77b2d18d9fb34a42b913e2e68feba7f Mon Sep 17 00:00:00 2001 From: Arthur <48595927+ArthurZucker@users.noreply.github.com> Date: Tue, 5 Aug 2025 15:12:14 +0200 Subject: [PATCH] Revert "remove dtensors, not explicit (#39840)" (#39912) * Revert "remove dtensors, not explicit (#39840)" This did not work with generation (lm_head needs extra care!) This reverts commit 6dfd561d9cd722dfc09f702355518c6d09b9b4e3. * update * style? --- conftest.py | 1 + .../integrations/tensor_parallel.py | 117 +++++++++--------- src/transformers/modeling_utils.py | 13 +- tests/tensor_parallel/test_tensor_parallel.py | 19 ++- 4 files changed, 77 insertions(+), 73 deletions(-) diff --git a/conftest.py b/conftest.py index 1abe8fb4a3..2134dceb84 100644 --- a/conftest.py +++ b/conftest.py @@ -130,6 +130,7 @@ doctest.DocTestParser = HfDocTestParser if is_torch_available(): import torch + # The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True. # We set it to `False` for CI. See https://github.com/pytorch/pytorch/issues/157274#issuecomment-3090791615 torch.backends.cudnn.allow_tf32 = False diff --git a/src/transformers/integrations/tensor_parallel.py b/src/transformers/integrations/tensor_parallel.py index d0fcfc1fff..353cc1d081 100644 --- a/src/transformers/integrations/tensor_parallel.py +++ b/src/transformers/integrations/tensor_parallel.py @@ -150,7 +150,6 @@ str_to_torch_dtype = { "F64": torch.float64, "I64": torch.int64, "F8_E4M3": torch.float8_e4m3fn, - "F8_E5M2": torch.float8_e5m2, } @@ -526,43 +525,6 @@ class ReplicateParallel(TensorParallelLayer): return param -class ReduceFromModelParallelRegion(torch.autograd.Function): - """ - All-reduce in forward pass, identity in backward pass. - This is the `g` function in the paper: https://arxiv.org/abs/1909.08053 - """ - - @staticmethod - def forward(ctx, x, device_mesh): - if device_mesh.size() == 1: - return x - dist.all_reduce(x, op=dist.ReduceOp.SUM, group=device_mesh.get_group()) - return x - - @staticmethod - def backward(ctx, grad_output): - return grad_output - - -class CopyToModelParallelRegion(torch.autograd.Function): - """ - Copy in forward pass, all-reduce in backward pass. - This is the `f` function in the paper: https://arxiv.org/abs/1909.08053 - """ - - @staticmethod - def forward(ctx, x, device_mesh): - ctx.device_mesh = device_mesh - return x - - @staticmethod - def backward(ctx, grad_output): - if ctx.device_mesh.size() == 1: - return grad_output - dist.all_reduce(grad_output, op=dist.ReduceOp.SUM, group=ctx.device_mesh.get_group()) - return grad_output - - class ColwiseParallel(TensorParallelLayer): """ General tensor parallel layer for transformers. @@ -585,8 +547,15 @@ class ColwiseParallel(TensorParallelLayer): @staticmethod def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh): + # TODO: figure out dynamo support for instance method and switch this to instance method # annotate module input placements/sharding with input_layouts input_tensor = inputs[0] + if not isinstance(input_tensor, DTensor): + input_tensor = DTensor.from_local(input_tensor, device_mesh, input_layouts, run_check=False) + + # transform the input layouts to the desired layouts of ColwiseParallel + if input_layouts != desired_input_layouts: + input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=False) return input_tensor def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh): @@ -595,19 +564,41 @@ class ColwiseParallel(TensorParallelLayer): # weight would become Shard(1) if param_type == "bias": parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -1) + shard = [Shard(-1)] else: + shard = [Shard(-2)] parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -2) parameter = parameter.to(param_casting_dtype) if to_contiguous: parameter = parameter.contiguous() - + if self.use_dtensor: + parameter = DTensor.from_local( + parameter, device_mesh, shard, run_check=False, shape=empty_param.size(), stride=empty_param.stride() + ) return nn.Parameter(parameter, requires_grad=parameter.is_floating_point()) @staticmethod def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh): - outputs = CopyToModelParallelRegion.apply(outputs, device_mesh) - return outputs + # outputs is a shard on last dimension DTensor, i.e. Shard(-1) + if outputs.placements != output_layouts: + outputs = outputs.redistribute(placements=output_layouts, async_op=False) + # back to local tensor + return outputs.to_local() if use_local_output and isinstance(outputs, DTensor) else outputs + + +class PackedColwiseParallel(ColwiseParallel): + def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh): + # colwise shard weight/bias to Shard(0), weight be Shard(-2) (0 if you have 1 dim only) + # means Colwise as Linear is input * weight^T + bias, where + # weight would become Shard(1) + parameter = get_packed_weights(param, empty_param, device_mesh, rank, -2) + parameter = parameter.to(param_casting_dtype) + if to_contiguous: + parameter = parameter.contiguous() + if self.use_dtensor: + parameter = DTensor.from_local(parameter, device_mesh, [Shard(-2)], run_check=False) + return nn.Parameter(parameter, requires_grad=parameter.is_floating_point()) class RowwiseParallel(TensorParallelLayer): @@ -644,15 +635,23 @@ class RowwiseParallel(TensorParallelLayer): self.use_dtensor = use_dtensor def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh): - if param_type == "bias": - parameter = param[:] - else: + # Rowwise shard weight to Shard(1), bias to Replicate(), weight be Shard(1) + # means Rowwise as nn.Linear is input * weight^T + bias, where + # weight would become Shard(0) + if param_type != "bias": parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -1) + shard = [Shard(-1)] + else: + shard = [Replicate()] + parameter = param[:] parameter = parameter.to(param_casting_dtype) if to_contiguous: parameter = parameter.contiguous() - + if self.use_dtensor: + parameter = DTensor.from_local( + parameter, device_mesh, shard, run_check=False, shape=empty_param.size(), stride=empty_param.stride() + ) return nn.Parameter(parameter, requires_grad=parameter.is_floating_point()) @staticmethod @@ -662,13 +661,24 @@ class RowwiseParallel(TensorParallelLayer): mod.bias = None input_tensor = inputs[0] + if not isinstance(input_tensor, DTensor): + input_tensor = DTensor.from_local(input_tensor, device_mesh, input_layouts, run_check=False) + + if input_layouts != desired_input_layouts: + input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=True) return input_tensor @staticmethod def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh): - outputs = ReduceFromModelParallelRegion.apply(outputs, device_mesh) + # Rowwise sharding produces partial output, depending on output layouts: + # 1. to replicate -> allreduce + # 2. to shard -> reduce_scatter + if outputs.placements != output_layouts: + outputs = outputs.redistribute(placements=output_layouts, async_op=True) + outputs = outputs.to_local() # otherwise the `+=` op will gather if hasattr(mod, "_bias"): outputs += mod._bias + # back to local tensor if use_local_output is True return outputs def prepare_module_tp(self, module: nn.Module, device_mesh) -> nn.Module: @@ -694,21 +704,6 @@ class RowwiseParallel(TensorParallelLayer): ) -class PackedColwiseParallel(ColwiseParallel): - def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh): - # NOTE(3outeille): need to be deprecated as no longer using dtensors - # colwise shard weight/bias to Shard(0), weight be Shard(-2) (0 if you have 1 dim only) - # means Colwise as Linear is input * weight^T + bias, where - # weight would become Shard(1) - parameter = get_packed_weights(param, empty_param, device_mesh, rank, -2) - parameter = parameter.to(param_casting_dtype) - if to_contiguous: - parameter = parameter.contiguous() - if self.use_dtensor: - parameter = DTensor.from_local(parameter, device_mesh, [Shard(-2)], run_check=False) - return nn.Parameter(parameter, requires_grad=parameter.is_floating_point()) - - class PackedRowwiseParallel(RowwiseParallel): def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh): # colwise shard weight/bias to Shard(0), weight be Shard(-2) (0 if you have 1 dim only) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 6e86cb1002..03e9cf5314 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -4087,16 +4087,9 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH for shard_file, tensors in filename_to_tensors: shard = {} for tensor in tensors: - if _is_dtensor_available and getattr(self, "_device_mesh", None) is not None: - plan = _get_parameter_tp_plan(tensor, self._tp_plan) - full_tensor = state_dict[tensor] - if isinstance(state_dict[tensor], DTensor): - full_tensor = full_tensor.full_tensor() - elif plan is not None: - shard_dim = -1 if "rowwise" in plan else 0 - gather_list = [torch.empty_like(full_tensor) for _ in range(self._device_mesh.size())] - torch.distributed.all_gather(gather_list, full_tensor) - full_tensor = torch.cat(gather_list, dim=shard_dim) + if _is_dtensor_available and isinstance(state_dict[tensor], DTensor): + full_tensor = state_dict[tensor].full_tensor() + # to get the correctly ordered tensor we need to repack if packed if _get_parameter_tp_plan(tensor, self._tp_plan) in ("local_packed_rowwise",): full_tensor = repack_weights(full_tensor, -1, self._tp_size, 2) shard[tensor] = full_tensor.contiguous() # only do contiguous after it's permuted correctly diff --git a/tests/tensor_parallel/test_tensor_parallel.py b/tests/tensor_parallel/test_tensor_parallel.py index a9a9f05b87..1904fc8bd1 100644 --- a/tests/tensor_parallel/test_tensor_parallel.py +++ b/tests/tensor_parallel/test_tensor_parallel.py @@ -101,6 +101,14 @@ class TestTensorParallel(TestCasePlus): model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto", tp_plan="auto") torch.distributed.barrier() + has_dtensor = 0 + for name, parameter in model.named_parameters(): + if isinstance(parameter.data, torch.distributed.tensor.DTensor): + has_dtensor = 1 + break + + assert has_dtensor == 1, "TP model must has DTensor" + tokenizer = AutoTokenizer.from_pretrained(model_id, legacy=False) prompt = "Can I help" @@ -110,8 +118,7 @@ class TestTensorParallel(TestCasePlus): next_token_logits = outputs[0][:, -1, :] next_token = torch.argmax(next_token_logits, dim=-1) response = tokenizer.decode(next_token) - print(response) - # assert response == "with" + assert response == "with" torch.distributed.barrier() torch.distributed.destroy_process_group() @@ -136,6 +143,14 @@ class TestTensorParallel(TestCasePlus): model.forward = torch.compile(model.forward) + has_dtensor = 0 + for name, parameter in model.named_parameters(): + if isinstance(parameter.data, torch.distributed.tensor.DTensor): + has_dtensor = 1 + break + + assert has_dtensor == 1, "TP model must has DTensor" + tokenizer = AutoTokenizer.from_pretrained(model_id) prompt = "Can I help"