remove dtensors, not explicit (#39840)

* remove dtensors, not explicit

Co-authored-by: 3outeille <3outeille@users.noreply.github.com>

* style

* fix test

* update

* as we broke saving, try to fix it

* output layouts should exist

* nit

* device mesh exists if the model was distributed

* use _device_mesh of self

* update

* lol

* fix

* nit

* update

* fix!

* this???

* grumble grumble

* ?

* fuck me

---------

Co-authored-by: 3outeille <3outeille@users.noreply.github.com>
This commit is contained in:
Arthur
2025-08-01 22:02:47 +02:00
committed by GitHub
parent b727c2b20e
commit 6dfd561d9c
3 changed files with 74 additions and 76 deletions

View File

@@ -4087,9 +4087,16 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
for shard_file, tensors in filename_to_tensors:
shard = {}
for tensor in tensors:
if _is_dtensor_available and isinstance(state_dict[tensor], DTensor):
full_tensor = state_dict[tensor].full_tensor()
# to get the correctly ordered tensor we need to repack if packed
if _is_dtensor_available and getattr(self, "_device_mesh", None) is not None:
plan = _get_parameter_tp_plan(tensor, self._tp_plan)
full_tensor = state_dict[tensor]
if isinstance(state_dict[tensor], DTensor):
full_tensor = full_tensor.full_tensor()
elif plan is not None:
shard_dim = -1 if "rowwise" in plan else 0
gather_list = [torch.empty_like(full_tensor) for _ in range(self._device_mesh.size())]
torch.distributed.all_gather(gather_list, full_tensor)
full_tensor = torch.cat(gather_list, dim=shard_dim)
if _get_parameter_tp_plan(tensor, self._tp_plan) in ("local_packed_rowwise",):
full_tensor = repack_weights(full_tensor, -1, self._tp_size, 2)
shard[tensor] = full_tensor.contiguous() # only do contiguous after it's permuted correctly