remove dtensors, not explicit (#39840)
* remove dtensors, not explicit Co-authored-by: 3outeille <3outeille@users.noreply.github.com> * style * fix test * update * as we broke saving try to fix * output layouts should exit * nit * devicemesh exists if it was distributed * use _device_mesh of self * update * lol * fix * nit * update * fix! * this??? * grumble grumble * ? * fuck me --------- Co-authored-by: 3outeille <3outeille@users.noreply.github.com>
This commit is contained in:
@@ -4087,9 +4087,16 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
||||
for shard_file, tensors in filename_to_tensors:
|
||||
shard = {}
|
||||
for tensor in tensors:
|
||||
if _is_dtensor_available and isinstance(state_dict[tensor], DTensor):
|
||||
full_tensor = state_dict[tensor].full_tensor()
|
||||
# to get the correctly ordered tensor we need to repack if packed
|
||||
if _is_dtensor_available and getattr(self, "_device_mesh", None) is not None:
|
||||
plan = _get_parameter_tp_plan(tensor, self._tp_plan)
|
||||
full_tensor = state_dict[tensor]
|
||||
if isinstance(state_dict[tensor], DTensor):
|
||||
full_tensor = full_tensor.full_tensor()
|
||||
elif plan is not None:
|
||||
shard_dim = -1 if "rowwise" in plan else 0
|
||||
gather_list = [torch.empty_like(full_tensor) for _ in range(self._device_mesh.size())]
|
||||
torch.distributed.all_gather(gather_list, full_tensor)
|
||||
full_tensor = torch.cat(gather_list, dim=shard_dim)
|
||||
if _get_parameter_tp_plan(tensor, self._tp_plan) in ("local_packed_rowwise",):
|
||||
full_tensor = repack_weights(full_tensor, -1, self._tp_size, 2)
|
||||
shard[tensor] = full_tensor.contiguous() # only do contiguous after it's permuted correctly
|
||||
|
||||
Reference in New Issue
Block a user