Revert "remove dtensors, not explicit (#39840)" (#39912)

* Revert "remove dtensors, not explicit (#39840)"
This did not work with generation (lm_head needs extra care!)
This reverts commit 6dfd561d9c.

* update

* style?
This commit is contained in:
Arthur
2025-08-05 15:12:14 +02:00
committed by GitHub
parent 2589a52c5c
commit 20ce210ab7
4 changed files with 77 additions and 73 deletions

View File

@@ -130,6 +130,7 @@ doctest.DocTestParser = HfDocTestParser
if is_torch_available(): if is_torch_available():
import torch import torch
# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True. # The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
# We set it to `False` for CI. See https://github.com/pytorch/pytorch/issues/157274#issuecomment-3090791615 # We set it to `False` for CI. See https://github.com/pytorch/pytorch/issues/157274#issuecomment-3090791615
torch.backends.cudnn.allow_tf32 = False torch.backends.cudnn.allow_tf32 = False

View File

@@ -150,7 +150,6 @@ str_to_torch_dtype = {
"F64": torch.float64, "F64": torch.float64,
"I64": torch.int64, "I64": torch.int64,
"F8_E4M3": torch.float8_e4m3fn, "F8_E4M3": torch.float8_e4m3fn,
"F8_E5M2": torch.float8_e5m2,
} }
@@ -526,43 +525,6 @@ class ReplicateParallel(TensorParallelLayer):
return param return param
class ReduceFromModelParallelRegion(torch.autograd.Function):
"""
All-reduce in forward pass, identity in backward pass.
This is the `g` function in the paper: https://arxiv.org/abs/1909.08053
"""
@staticmethod
def forward(ctx, x, device_mesh):
if device_mesh.size() == 1:
return x
dist.all_reduce(x, op=dist.ReduceOp.SUM, group=device_mesh.get_group())
return x
@staticmethod
def backward(ctx, grad_output):
return grad_output
class CopyToModelParallelRegion(torch.autograd.Function):
"""
Copy in forward pass, all-reduce in backward pass.
This is the `f` function in the paper: https://arxiv.org/abs/1909.08053
"""
@staticmethod
def forward(ctx, x, device_mesh):
ctx.device_mesh = device_mesh
return x
@staticmethod
def backward(ctx, grad_output):
if ctx.device_mesh.size() == 1:
return grad_output
dist.all_reduce(grad_output, op=dist.ReduceOp.SUM, group=ctx.device_mesh.get_group())
return grad_output
class ColwiseParallel(TensorParallelLayer): class ColwiseParallel(TensorParallelLayer):
""" """
General tensor parallel layer for transformers. General tensor parallel layer for transformers.
@@ -585,8 +547,15 @@ class ColwiseParallel(TensorParallelLayer):
@staticmethod @staticmethod
def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh): def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
# TODO: figure out dynamo support for instance method and switch this to instance method
# annotate module input placements/sharding with input_layouts # annotate module input placements/sharding with input_layouts
input_tensor = inputs[0] input_tensor = inputs[0]
if not isinstance(input_tensor, DTensor):
input_tensor = DTensor.from_local(input_tensor, device_mesh, input_layouts, run_check=False)
# transform the input layouts to the desired layouts of ColwiseParallel
if input_layouts != desired_input_layouts:
input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=False)
return input_tensor return input_tensor
def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh): def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
@@ -595,19 +564,41 @@ class ColwiseParallel(TensorParallelLayer):
# weight would become Shard(1) # weight would become Shard(1)
if param_type == "bias": if param_type == "bias":
parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -1) parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -1)
shard = [Shard(-1)]
else: else:
shard = [Shard(-2)]
parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -2) parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -2)
parameter = parameter.to(param_casting_dtype) parameter = parameter.to(param_casting_dtype)
if to_contiguous: if to_contiguous:
parameter = parameter.contiguous() parameter = parameter.contiguous()
if self.use_dtensor:
parameter = DTensor.from_local(
parameter, device_mesh, shard, run_check=False, shape=empty_param.size(), stride=empty_param.stride()
)
return nn.Parameter(parameter, requires_grad=parameter.is_floating_point()) return nn.Parameter(parameter, requires_grad=parameter.is_floating_point())
@staticmethod @staticmethod
def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh): def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
outputs = CopyToModelParallelRegion.apply(outputs, device_mesh) # outputs is a shard on last dimension DTensor, i.e. Shard(-1)
return outputs if outputs.placements != output_layouts:
outputs = outputs.redistribute(placements=output_layouts, async_op=False)
# back to local tensor
return outputs.to_local() if use_local_output and isinstance(outputs, DTensor) else outputs
class PackedColwiseParallel(ColwiseParallel):
def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
# colwise shard weight/bias to Shard(0), weight be Shard(-2) (0 if you have 1 dim only)
# means Colwise as Linear is input * weight^T + bias, where
# weight would become Shard(1)
parameter = get_packed_weights(param, empty_param, device_mesh, rank, -2)
parameter = parameter.to(param_casting_dtype)
if to_contiguous:
parameter = parameter.contiguous()
if self.use_dtensor:
parameter = DTensor.from_local(parameter, device_mesh, [Shard(-2)], run_check=False)
return nn.Parameter(parameter, requires_grad=parameter.is_floating_point())
class RowwiseParallel(TensorParallelLayer): class RowwiseParallel(TensorParallelLayer):
@@ -644,15 +635,23 @@ class RowwiseParallel(TensorParallelLayer):
self.use_dtensor = use_dtensor self.use_dtensor = use_dtensor
def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh): def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
if param_type == "bias": # Rowwise shard weight to Shard(1), bias to Replicate(), weight be Shard(1)
parameter = param[:] # means Rowwise as nn.Linear is input * weight^T + bias, where
else: # weight would become Shard(0)
if param_type != "bias":
parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -1) parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -1)
shard = [Shard(-1)]
else:
shard = [Replicate()]
parameter = param[:]
parameter = parameter.to(param_casting_dtype) parameter = parameter.to(param_casting_dtype)
if to_contiguous: if to_contiguous:
parameter = parameter.contiguous() parameter = parameter.contiguous()
if self.use_dtensor:
parameter = DTensor.from_local(
parameter, device_mesh, shard, run_check=False, shape=empty_param.size(), stride=empty_param.stride()
)
return nn.Parameter(parameter, requires_grad=parameter.is_floating_point()) return nn.Parameter(parameter, requires_grad=parameter.is_floating_point())
@staticmethod @staticmethod
@@ -662,13 +661,24 @@ class RowwiseParallel(TensorParallelLayer):
mod.bias = None mod.bias = None
input_tensor = inputs[0] input_tensor = inputs[0]
if not isinstance(input_tensor, DTensor):
input_tensor = DTensor.from_local(input_tensor, device_mesh, input_layouts, run_check=False)
if input_layouts != desired_input_layouts:
input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=True)
return input_tensor return input_tensor
@staticmethod @staticmethod
def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh): def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
outputs = ReduceFromModelParallelRegion.apply(outputs, device_mesh) # Rowwise sharding produces partial output, depending on output layouts:
# 1. to replicate -> allreduce
# 2. to shard -> reduce_scatter
if outputs.placements != output_layouts:
outputs = outputs.redistribute(placements=output_layouts, async_op=True)
outputs = outputs.to_local() # otherwise the `+=` op will gather
if hasattr(mod, "_bias"): if hasattr(mod, "_bias"):
outputs += mod._bias outputs += mod._bias
# back to local tensor if use_local_output is True
return outputs return outputs
def prepare_module_tp(self, module: nn.Module, device_mesh) -> nn.Module: def prepare_module_tp(self, module: nn.Module, device_mesh) -> nn.Module:
@@ -694,21 +704,6 @@ class RowwiseParallel(TensorParallelLayer):
) )
class PackedColwiseParallel(ColwiseParallel):
def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
# NOTE(3outeille): need to be deprecated as no longer using dtensors
# colwise shard weight/bias to Shard(0), weight be Shard(-2) (0 if you have 1 dim only)
# means Colwise as Linear is input * weight^T + bias, where
# weight would become Shard(1)
parameter = get_packed_weights(param, empty_param, device_mesh, rank, -2)
parameter = parameter.to(param_casting_dtype)
if to_contiguous:
parameter = parameter.contiguous()
if self.use_dtensor:
parameter = DTensor.from_local(parameter, device_mesh, [Shard(-2)], run_check=False)
return nn.Parameter(parameter, requires_grad=parameter.is_floating_point())
class PackedRowwiseParallel(RowwiseParallel): class PackedRowwiseParallel(RowwiseParallel):
def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh): def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
# colwise shard weight/bias to Shard(0), weight be Shard(-2) (0 if you have 1 dim only) # colwise shard weight/bias to Shard(0), weight be Shard(-2) (0 if you have 1 dim only)

View File

@@ -4087,16 +4087,9 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
for shard_file, tensors in filename_to_tensors: for shard_file, tensors in filename_to_tensors:
shard = {} shard = {}
for tensor in tensors: for tensor in tensors:
if _is_dtensor_available and getattr(self, "_device_mesh", None) is not None: if _is_dtensor_available and isinstance(state_dict[tensor], DTensor):
plan = _get_parameter_tp_plan(tensor, self._tp_plan) full_tensor = state_dict[tensor].full_tensor()
full_tensor = state_dict[tensor] # to get the correctly ordered tensor we need to repack if packed
if isinstance(state_dict[tensor], DTensor):
full_tensor = full_tensor.full_tensor()
elif plan is not None:
shard_dim = -1 if "rowwise" in plan else 0
gather_list = [torch.empty_like(full_tensor) for _ in range(self._device_mesh.size())]
torch.distributed.all_gather(gather_list, full_tensor)
full_tensor = torch.cat(gather_list, dim=shard_dim)
if _get_parameter_tp_plan(tensor, self._tp_plan) in ("local_packed_rowwise",): if _get_parameter_tp_plan(tensor, self._tp_plan) in ("local_packed_rowwise",):
full_tensor = repack_weights(full_tensor, -1, self._tp_size, 2) full_tensor = repack_weights(full_tensor, -1, self._tp_size, 2)
shard[tensor] = full_tensor.contiguous() # only do contiguous after it's permuted correctly shard[tensor] = full_tensor.contiguous() # only do contiguous after it's permuted correctly

View File

@@ -101,6 +101,14 @@ class TestTensorParallel(TestCasePlus):
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto", tp_plan="auto") model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto", tp_plan="auto")
torch.distributed.barrier() torch.distributed.barrier()
has_dtensor = 0
for name, parameter in model.named_parameters():
if isinstance(parameter.data, torch.distributed.tensor.DTensor):
has_dtensor = 1
break
assert has_dtensor == 1, "TP model must has DTensor"
tokenizer = AutoTokenizer.from_pretrained(model_id, legacy=False) tokenizer = AutoTokenizer.from_pretrained(model_id, legacy=False)
prompt = "Can I help" prompt = "Can I help"
@@ -110,8 +118,7 @@ class TestTensorParallel(TestCasePlus):
next_token_logits = outputs[0][:, -1, :] next_token_logits = outputs[0][:, -1, :]
next_token = torch.argmax(next_token_logits, dim=-1) next_token = torch.argmax(next_token_logits, dim=-1)
response = tokenizer.decode(next_token) response = tokenizer.decode(next_token)
print(response) assert response == "with"
# assert response == "with"
torch.distributed.barrier() torch.distributed.barrier()
torch.distributed.destroy_process_group() torch.distributed.destroy_process_group()
@@ -136,6 +143,14 @@ class TestTensorParallel(TestCasePlus):
model.forward = torch.compile(model.forward) model.forward = torch.compile(model.forward)
has_dtensor = 0
for name, parameter in model.named_parameters():
if isinstance(parameter.data, torch.distributed.tensor.DTensor):
has_dtensor = 1
break
assert has_dtensor == 1, "TP model must has DTensor"
tokenizer = AutoTokenizer.from_pretrained(model_id) tokenizer = AutoTokenizer.from_pretrained(model_id)
prompt = "Can I help" prompt = "Can I help"