Feat: save_pretrained for tensor parallel (and other parallelisms) models (#37919)
* tmp: initial save pretrained with dtensors
* Feat: add correctness tests
* Refactor: version checks
* Temp: 1:1 checkpoint llama4
* refactor
* Tests
* Feat: works
* Style
* Feat: version checks + minor fixes
* Style
* Fix: version checks in tests
* Feat: move more stuff into tensor_parallel.py
This commit is contained in:
@@ -63,6 +63,9 @@ from .integrations.flex_attention import flex_attention_forward
|
||||
from .integrations.sdpa_attention import sdpa_attention_forward
|
||||
from .integrations.tensor_parallel import (
|
||||
SUPPORTED_TP_STYLES,
|
||||
_get_parameter_tp_plan,
|
||||
repack_weights,
|
||||
replace_state_dict_local_with_dtensor,
|
||||
shard_and_distribute_module,
|
||||
verify_tp_plan,
|
||||
)
|
||||
@@ -123,6 +126,7 @@ from .utils import (
|
||||
from .utils.hub import create_and_tag_model_card, get_checkpoint_shard_files
|
||||
from .utils.import_utils import (
|
||||
ENV_VARS_TRUE_VALUES,
|
||||
is_huggingface_hub_greater_or_equal,
|
||||
is_sagemaker_mp_enabled,
|
||||
is_torch_fx_proxy,
|
||||
is_torchdynamo_compiling,
|
||||
@@ -168,6 +172,9 @@ _is_quantized = False
|
||||
_is_ds_init_called = False
|
||||
_torch_distributed_available = torch.distributed.is_available()
|
||||
|
||||
if _torch_distributed_available and is_torch_greater_or_equal("2.5"):
|
||||
from torch.distributed.tensor import DTensor
|
||||
|
||||
|
||||
def is_fsdp_enabled():
|
||||
return (
|
||||
@@ -3413,6 +3420,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
if safe_serialization and not is_safetensors_available():
|
||||
raise ImportError("`safe_serialization` requires the `safetensors library: `pip install safetensors`.")
|
||||
|
||||
# we need to check against tp_size, not tp_plan, as tp_plan is substituted to the class one
|
||||
if self._tp_size is not None and not is_huggingface_hub_greater_or_equal("0.31.4"):
|
||||
raise ImportError(
|
||||
"Saving a model with tensor parallelism requires `huggingface_hub` version 0.31.4 or higher."
|
||||
)
|
||||
|
||||
if os.path.isfile(save_directory):
|
||||
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
|
||||
return
|
||||
@@ -3540,6 +3553,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
# Rename state_dict keys before saving to file. Do nothing unless overridden in a particular model.
|
||||
# (initially introduced with TimmWrapperModel to remove prefix and make checkpoints compatible with timm)
|
||||
state_dict = self._fix_state_dict_keys_on_save(state_dict)
|
||||
# If the model was sharded, we cannot properly determine the sizes of tensors for which a `local_*` strategy was used,
|
||||
# therefore we replace them with DTensors that are equivalently sharded
|
||||
if self._tp_size is not None:
|
||||
state_dict = replace_state_dict_local_with_dtensor(state_dict, self._tp_plan, self._device_mesh)
|
||||
|
||||
if safe_serialization:
|
||||
# Safetensors does not allow tensor aliasing.
|
||||
@@ -3548,7 +3565,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
for name, tensor in state_dict.items():
|
||||
# Sometimes in the state_dict we have non-tensor objects.
|
||||
# e.g. in bitsandbytes we have some `str` objects in the state_dict
|
||||
if isinstance(tensor, torch.Tensor):
|
||||
if isinstance(tensor, torch.Tensor) or isinstance(tensor, DTensor):
|
||||
ptrs[id_tensor_storage(tensor)].append(name)
|
||||
else:
|
||||
# In the non-tensor case, fall back to the pointer of the object itself
|
||||
@@ -3658,7 +3675,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
for shard_file, tensors in filename_to_tensors:
|
||||
shard = {}
|
||||
for tensor in tensors:
|
||||
shard[tensor] = state_dict[tensor].contiguous()
|
||||
if isinstance(state_dict[tensor], DTensor):
|
||||
full_tensor = state_dict[tensor].full_tensor()
|
||||
# to get the correctly ordered tensor we need to repack if packed
|
||||
if _get_parameter_tp_plan(tensor, self._tp_plan) in ("local_packed_rowwise",):
|
||||
full_tensor = repack_weights(full_tensor, -1, self._tp_size, 2)
|
||||
shard[tensor] = full_tensor.contiguous() # only do contiguous after it's permuted correctly
|
||||
else:
|
||||
shard[tensor] = state_dict[tensor].contiguous()
|
||||
# delete reference, see https://github.com/huggingface/transformers/pull/34890
|
||||
del state_dict[tensor]
|
||||
|
||||
@@ -4606,6 +4630,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
|
||||
# record the tp degree the model was sharded to
|
||||
model._tp_size = tp_size
|
||||
model._device_mesh = device_mesh
|
||||
|
||||
# make sure token embedding weights are still tied if needed
|
||||
model.tie_weights()
|
||||
|
||||
Reference in New Issue
Block a user