Feat: save_pretrained for tensor parallel (and other parallelisms) models (#37919)

* tmp: initial save pretrained with dtensors

* Feat: add correctness tests

* Refactor: version checks

* Temp: 1:1 checkpoint llama4

* refactor

* Tests

* Feat: works

* Style

* Feat: version checks + minor fixes

* Style

* Fix: version checks in tests

* Feat: move more stuff into tensor_parallel.py
This commit is contained in:
Matej Sirovatka
2025-05-19 20:16:21 +02:00
committed by GitHub
parent 9ecee14378
commit 46a4b7c909
7 changed files with 271 additions and 12 deletions

View File

@@ -63,6 +63,9 @@ from .integrations.flex_attention import flex_attention_forward
from .integrations.sdpa_attention import sdpa_attention_forward
from .integrations.tensor_parallel import (
SUPPORTED_TP_STYLES,
_get_parameter_tp_plan,
repack_weights,
replace_state_dict_local_with_dtensor,
shard_and_distribute_module,
verify_tp_plan,
)
@@ -123,6 +126,7 @@ from .utils import (
from .utils.hub import create_and_tag_model_card, get_checkpoint_shard_files
from .utils.import_utils import (
ENV_VARS_TRUE_VALUES,
is_huggingface_hub_greater_or_equal,
is_sagemaker_mp_enabled,
is_torch_fx_proxy,
is_torchdynamo_compiling,
@@ -168,6 +172,9 @@ _is_quantized = False
_is_ds_init_called = False
_torch_distributed_available = torch.distributed.is_available()
if _torch_distributed_available and is_torch_greater_or_equal("2.5"):
from torch.distributed.tensor import DTensor
def is_fsdp_enabled():
return (
@@ -3413,6 +3420,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
if safe_serialization and not is_safetensors_available():
raise ImportError("`safe_serialization` requires the `safetensors library: `pip install safetensors`.")
# we need to check against tp_size, not tp_plan, as tp_plan is substituted to the class one
if self._tp_size is not None and not is_huggingface_hub_greater_or_equal("0.31.4"):
raise ImportError(
"Saving a model with tensor parallelism requires `huggingface_hub` version 0.31.4 or higher."
)
if os.path.isfile(save_directory):
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
return
@@ -3540,6 +3553,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
# Rename state_dict keys before saving to file. Do nothing unless overridden in a particular model.
# (initially introduced with TimmWrapperModel to remove prefix and make checkpoints compatible with timm)
state_dict = self._fix_state_dict_keys_on_save(state_dict)
# If model was sharded, we cannot properly determine sizes of tensors that `local_*` strategy was used,
# therefore we replace them with DTensors that are equivalently sharded
if self._tp_size is not None:
state_dict = replace_state_dict_local_with_dtensor(state_dict, self._tp_plan, self._device_mesh)
if safe_serialization:
# Safetensors does not allow tensor aliasing.
@@ -3548,7 +3565,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
for name, tensor in state_dict.items():
# Sometimes in the state_dict we have non-tensor objects.
# e.g. in bitsandbytes we have some `str` objects in the state_dict
if isinstance(tensor, torch.Tensor):
if isinstance(tensor, torch.Tensor) or isinstance(tensor, DTensor):
ptrs[id_tensor_storage(tensor)].append(name)
else:
# In the non-tensor case, fall back to the pointer of the object itself
@@ -3658,7 +3675,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
for shard_file, tensors in filename_to_tensors:
shard = {}
for tensor in tensors:
shard[tensor] = state_dict[tensor].contiguous()
if isinstance(state_dict[tensor], DTensor):
full_tensor = state_dict[tensor].full_tensor()
# to get the correctly ordered tensor we need to repack if packed
if _get_parameter_tp_plan(tensor, self._tp_plan) in ("local_packed_rowwise",):
full_tensor = repack_weights(full_tensor, -1, self._tp_size, 2)
shard[tensor] = full_tensor.contiguous() # only do contiguous after it's permuted correctly
else:
shard[tensor] = state_dict[tensor].contiguous()
# delete reference, see https://github.com/huggingface/transformers/pull/34890
del state_dict[tensor]
@@ -4606,6 +4630,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
# record tp degree the model sharded to
model._tp_size = tp_size
model._device_mesh = device_mesh
# make sure token embedding weights are still tied if needed
model.tie_weights()