Feat: save_pretrained for tensor parallel (and other parallelisms) models (#37919)

* tmp: initial save pretrained with dtensors

* Feat: add correctness tests

* Refactor: version checks

* Temp: 1:1 checkpoint llama4

* refactor

* Tests

* Feat: works

* Style

* Feat: version checks + minor fixes

* Style

* Fix: version checks in tests

* Feat: move more stuff into tensor_parallel.py
This commit is contained in:
Matej Sirovatka
2025-05-19 20:16:21 +02:00
committed by GitHub
parent 9ecee14378
commit 46a4b7c909
7 changed files with 271 additions and 12 deletions

View File

@@ -61,6 +61,22 @@ def _blocks_to_block_sizes(total_size: int, blocks: Union[int, List[int]]) -> Li
return [single_size] * blocks return [single_size] * blocks
def _get_parameter_tp_plan(parameter_name: str, tp_plan: dict[str, str]) -> Optional[str]:
"""
Get the TP style for a parameter from the TP plan.
The TP plan is a dictionary that maps parameter names to TP styles.
The parameter name can be a generic name with wildcards (e.g. "*.weight") or a specific name (e.g. "layer_1.weight").
"""
generic_param_name = re.sub(r"\d+", "*", parameter_name)
if generic_param_name in tp_plan:
return tp_plan[generic_param_name]
elif "." in generic_param_name and generic_param_name.rsplit(".", 1)[0] in tp_plan:
return tp_plan[generic_param_name.rsplit(".", 1)[0]]
else:
return None
str_to_torch_dtype = { str_to_torch_dtype = {
"BOOL": torch.bool, "BOOL": torch.bool,
"U8": torch.uint8, "U8": torch.uint8,
@@ -138,6 +154,71 @@ def get_packed_weights(param, empty_param, device_mesh, rank, dim):
return tensor.to(str_to_torch_dtype[slice_dtype]) return tensor.to(str_to_torch_dtype[slice_dtype])
def repack_weights(
packed_parameter: torch.Tensor,
sharded_dim: int, # The dimension index in the global tensor that was sharded
world_size: int,
num_blocks: int = 2,
) -> torch.Tensor:
"""
Reorders a tensor that was reconstructed from sharded packed weights into its canonical packed format.
For example, if a weight was packed (e.g., gate_proj and up_proj) and then sharded,
DTensor.full_tensor() might produce an interleaved layout like [G0, U0, G1, U1, ...]
along the sharded dimension. This function reorders it to [G0, G1, ..., U0, U1, ...].
This is an inverse operation to get_packed_weights.
Args:
reconstructed_tensor: The tensor reconstructed from DTensor (e.g., via .full_tensor().contiguous()).
sharded_dim: The dimension index in the reconstructed_tensor that was originally sharded.
world_size: The tensor parallel world size.
num_packed_projs: The number of projections that were packed together (e.g., 2 for gate_up_proj).
Returns:
The reordered tensor in canonical packed format.
"""
if num_blocks != 2:
raise ValueError(
"Num blocks different from 2 is not supported yet. This is most likely a bug in your implementation as we only pack gate and up projections together."
)
actual_sharded_dim = sharded_dim if sharded_dim >= 0 else sharded_dim + packed_parameter.ndim
total_size_on_sharded_dim = packed_parameter.shape[actual_sharded_dim]
original_block_size_on_dim = total_size_on_sharded_dim // num_blocks
shard_chunk_size = original_block_size_on_dim // world_size
prefix_shape = packed_parameter.shape[:actual_sharded_dim]
suffix_shape = packed_parameter.shape[actual_sharded_dim + 1 :]
tensor_view = packed_parameter.view(
*prefix_shape,
world_size,
num_blocks,
shard_chunk_size,
*suffix_shape,
)
# Permute to bring num_packed_projs first, then world_size, then shard_chunk_size
# This groups all chunks of G together, then all chunks of U together.
# Target order of these middle dimensions: (num_packed_projs, world_size, shard_chunk_size)
# Current order of view's middle dimensions: (world_size, num_packed_projs, shard_chunk_size)
# Absolute indices of the dimensions to be permuted (world_size, num_packed_projs)
axis_ws_abs = len(prefix_shape)
axis_npp_abs = len(prefix_shape) + 1
permute_order = list(range(tensor_view.ndim))
permute_order[axis_ws_abs], permute_order[axis_npp_abs] = permute_order[axis_npp_abs], permute_order[axis_ws_abs]
tensor_permuted = tensor_view.permute(*permute_order)
# Reshape back to the original tensor's ndim, with the sharded dimension now correctly ordered as [G_all, U_all].
# The final shape should be the same as reconstructed_tensor.
final_ordered_tensor = tensor_permuted.reshape_as(packed_parameter)
return final_ordered_tensor
def get_tensor_shard(param, empty_param, device_mesh, rank, dim): def get_tensor_shard(param, empty_param, device_mesh, rank, dim):
if dim == 0: if dim == 0:
size_ = empty_param.shape[0] size_ = empty_param.shape[0]
@@ -578,6 +659,49 @@ def translate_to_torch_parallel_style(style: str):
raise ValueError(f"Unsupported parallel style value: {style}") raise ValueError(f"Unsupported parallel style value: {style}")
def convert_local_tensor_to_dtensor(
parameter: torch.Tensor, parameter_name: str, device_mesh, tp_plan: dict[str, str]
) -> DTensor:
"""
Converts a local variant of weights to a DTensor with corresponding placements. Shouldn't be done ever except of before saving the model.
"""
_, param_type = parameter_name.rsplit(".", 1) if "." in parameter_name else parameter_name
tp_style = _get_parameter_tp_plan(parameter_name, tp_plan)
if not tp_style:
return parameter
if tp_style not in ["local_packed_rowwise", "local_rowwise", "local_colwise"]:
return parameter
# TODO: this logic should be wrapped in a function, this is copied from corresponding tp classes.
if tp_style == "local_packed_rowwise":
placements = [Shard(-1)]
elif tp_style == "local_rowwise":
if param_type == "bias":
placements = [Replicate()]
else:
placements = [Shard(-1)]
elif tp_style == "local_colwise":
if param_type == "bias":
placements = [Shard(-1)]
else:
placements = [Shard(-2)]
return DTensor.from_local(parameter, device_mesh, placements, run_check=False)
def replace_state_dict_local_with_dtensor(
state_dict: dict[str, torch.Tensor],
tp_plan: dict[str, str],
device_mesh,
) -> dict[str, torch.Tensor]:
"""
Replaces all tensors that were sharded with `local_*` strategy with DTensor to make determining their proper size possible.
"""
for key, value in state_dict.items():
if isinstance(value, torch.Tensor) and not isinstance(value, DTensor):
state_dict[key] = convert_local_tensor_to_dtensor(value, key, device_mesh, tp_plan)
return state_dict
def add_tensor_parallel_hooks_to_module(model, module, tp_plan, layer_name, current_module_plan, device_mesh): def add_tensor_parallel_hooks_to_module(model, module, tp_plan, layer_name, current_module_plan, device_mesh):
""" """
Add hooks to the module holding the layer. Meaning: Add hooks to the module holding the layer. Meaning:
@@ -632,13 +756,9 @@ def shard_and_distribute_module(
param_name, param_type = parameter_name.rsplit(".", 1) if "." in parameter_name else parameter_name param_name, param_type = parameter_name.rsplit(".", 1) if "." in parameter_name else parameter_name
tp_plan = model._tp_plan tp_plan = model._tp_plan
module_to_tp = model.get_submodule(param_name) module_to_tp = model.get_submodule(param_name)
current_module_plan = None
rank = int(rank) rank = int(rank)
generic_param_name = re.sub(r"\d+", "*", parameter_name)
if generic_param_name in tp_plan: current_module_plan = _get_parameter_tp_plan(parameter_name, tp_plan)
current_module_plan = tp_plan[generic_param_name]
elif "." in generic_param_name and generic_param_name.rsplit(".", 1)[0] in tp_plan:
current_module_plan = tp_plan[generic_param_name.rsplit(".", 1)[0]]
# Add hooks to the module if not done yet # Add hooks to the module if not done yet
# add_tensor_parallel_hooks_to_module(model, module_to_tp, tp_plan, param_name, current_module_plan, device_mesh) # add_tensor_parallel_hooks_to_module(model, module_to_tp, tp_plan, param_name, current_module_plan, device_mesh)

View File

@@ -63,6 +63,9 @@ from .integrations.flex_attention import flex_attention_forward
from .integrations.sdpa_attention import sdpa_attention_forward from .integrations.sdpa_attention import sdpa_attention_forward
from .integrations.tensor_parallel import ( from .integrations.tensor_parallel import (
SUPPORTED_TP_STYLES, SUPPORTED_TP_STYLES,
_get_parameter_tp_plan,
repack_weights,
replace_state_dict_local_with_dtensor,
shard_and_distribute_module, shard_and_distribute_module,
verify_tp_plan, verify_tp_plan,
) )
@@ -123,6 +126,7 @@ from .utils import (
from .utils.hub import create_and_tag_model_card, get_checkpoint_shard_files from .utils.hub import create_and_tag_model_card, get_checkpoint_shard_files
from .utils.import_utils import ( from .utils.import_utils import (
ENV_VARS_TRUE_VALUES, ENV_VARS_TRUE_VALUES,
is_huggingface_hub_greater_or_equal,
is_sagemaker_mp_enabled, is_sagemaker_mp_enabled,
is_torch_fx_proxy, is_torch_fx_proxy,
is_torchdynamo_compiling, is_torchdynamo_compiling,
@@ -168,6 +172,9 @@ _is_quantized = False
_is_ds_init_called = False _is_ds_init_called = False
_torch_distributed_available = torch.distributed.is_available() _torch_distributed_available = torch.distributed.is_available()
if _torch_distributed_available and is_torch_greater_or_equal("2.5"):
from torch.distributed.tensor import DTensor
def is_fsdp_enabled(): def is_fsdp_enabled():
return ( return (
@@ -3413,6 +3420,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
if safe_serialization and not is_safetensors_available(): if safe_serialization and not is_safetensors_available():
raise ImportError("`safe_serialization` requires the `safetensors library: `pip install safetensors`.") raise ImportError("`safe_serialization` requires the `safetensors library: `pip install safetensors`.")
# we need to check against tp_size, not tp_plan, as tp_plan is substituted to the class one
if self._tp_size is not None and not is_huggingface_hub_greater_or_equal("0.31.4"):
raise ImportError(
"Saving a model with tensor parallelism requires `huggingface_hub` version 0.31.4 or higher."
)
if os.path.isfile(save_directory): if os.path.isfile(save_directory):
logger.error(f"Provided path ({save_directory}) should be a directory, not a file") logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
return return
@@ -3540,6 +3553,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
# Rename state_dict keys before saving to file. Do nothing unless overridden in a particular model. # Rename state_dict keys before saving to file. Do nothing unless overridden in a particular model.
# (initially introduced with TimmWrapperModel to remove prefix and make checkpoints compatible with timm) # (initially introduced with TimmWrapperModel to remove prefix and make checkpoints compatible with timm)
state_dict = self._fix_state_dict_keys_on_save(state_dict) state_dict = self._fix_state_dict_keys_on_save(state_dict)
# If model was sharded, we cannot properly determine sizes of tensors that `local_*` strategy was used,
# therefore we replace them with DTensors that are equivalently sharded
if self._tp_size is not None:
state_dict = replace_state_dict_local_with_dtensor(state_dict, self._tp_plan, self._device_mesh)
if safe_serialization: if safe_serialization:
# Safetensors does not allow tensor aliasing. # Safetensors does not allow tensor aliasing.
@@ -3548,7 +3565,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
for name, tensor in state_dict.items(): for name, tensor in state_dict.items():
# Sometimes in the state_dict we have non-tensor objects. # Sometimes in the state_dict we have non-tensor objects.
# e.g. in bitsandbytes we have some `str` objects in the state_dict # e.g. in bitsandbytes we have some `str` objects in the state_dict
if isinstance(tensor, torch.Tensor): if isinstance(tensor, torch.Tensor) or isinstance(tensor, DTensor):
ptrs[id_tensor_storage(tensor)].append(name) ptrs[id_tensor_storage(tensor)].append(name)
else: else:
# In the non-tensor case, fall back to the pointer of the object itself # In the non-tensor case, fall back to the pointer of the object itself
@@ -3658,7 +3675,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
for shard_file, tensors in filename_to_tensors: for shard_file, tensors in filename_to_tensors:
shard = {} shard = {}
for tensor in tensors: for tensor in tensors:
shard[tensor] = state_dict[tensor].contiguous() if isinstance(state_dict[tensor], DTensor):
full_tensor = state_dict[tensor].full_tensor()
# to get the correctly ordered tensor we need to repack if packed
if _get_parameter_tp_plan(tensor, self._tp_plan) in ("local_packed_rowwise",):
full_tensor = repack_weights(full_tensor, -1, self._tp_size, 2)
shard[tensor] = full_tensor.contiguous() # only do contiguous after it's permuted correctly
else:
shard[tensor] = state_dict[tensor].contiguous()
# delete reference, see https://github.com/huggingface/transformers/pull/34890 # delete reference, see https://github.com/huggingface/transformers/pull/34890
del state_dict[tensor] del state_dict[tensor]
@@ -4606,6 +4630,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
# record tp degree the model sharded to # record tp degree the model sharded to
model._tp_size = tp_size model._tp_size = tp_size
model._device_mesh = device_mesh
# make sure token embedding weights are still tied if needed # make sure token embedding weights are still tied if needed
model.tie_weights() model.tie_weights()

View File

@@ -296,6 +296,13 @@ def id_tensor_storage(tensor: torch.Tensor) -> tuple[torch.device, int, int]:
guaranteed to be unique and constant for this tensor's storage during its lifetime. Two tensor storages with guaranteed to be unique and constant for this tensor's storage during its lifetime. Two tensor storages with
non-overlapping lifetimes may have the same id. non-overlapping lifetimes may have the same id.
""" """
if is_torch_greater_or_equal_than_2_0:
from torch.distributed.tensor import DTensor
if isinstance(tensor, DTensor):
local_tensor = tensor.to_local()
return tensor.device, local_tensor.storage().data_ptr(), tensor.nbytes
if tensor.device.type == "xla" and is_torch_xla_available(): if tensor.device.type == "xla" and is_torch_xla_available():
# NOTE: xla tensors dont have storage # NOTE: xla tensors dont have storage
# use some other unique id to distinguish. # use some other unique id to distinguish.

View File

@@ -97,6 +97,7 @@ from .utils import (
is_grokadamw_available, is_grokadamw_available,
is_hadamard_available, is_hadamard_available,
is_hqq_available, is_hqq_available,
is_huggingface_hub_greater_or_equal,
is_ipex_available, is_ipex_available,
is_jieba_available, is_jieba_available,
is_jinja_available, is_jinja_available,
@@ -542,6 +543,21 @@ def require_torch_greater_or_equal(version: str):
return decorator return decorator
def require_huggingface_hub_greater_or_equal(version: str):
"""
Decorator marking a test that requires huggingface_hub version >= `version`.
These tests are skipped when huggingface_hub version is less than `version`.
"""
def decorator(test_case):
return unittest.skipUnless(
is_huggingface_hub_greater_or_equal(version), f"test requires huggingface_hub version >= {version}"
)(test_case)
return decorator
def require_flash_attn(test_case): def require_flash_attn(test_case):
""" """
Decorator marking a test that requires Flash Attention. Decorator marking a test that requires Flash Attention.

View File

@@ -167,6 +167,7 @@ from .import_utils import (
is_habana_gaudi1, is_habana_gaudi1,
is_hadamard_available, is_hadamard_available,
is_hqq_available, is_hqq_available,
is_huggingface_hub_greater_or_equal,
is_in_notebook, is_in_notebook,
is_ipex_available, is_ipex_available,
is_jieba_available, is_jieba_available,

View File

@@ -1077,6 +1077,19 @@ def is_torch_greater_or_equal(library_version: str, accept_dev: bool = False):
return version.parse(importlib.metadata.version("torch")) >= version.parse(library_version) return version.parse(importlib.metadata.version("torch")) >= version.parse(library_version)
@lru_cache()
def is_huggingface_hub_greater_or_equal(library_version: str, accept_dev: bool = False):
if not _is_package_available("huggingface_hub"):
return False
if accept_dev:
return version.parse(
version.parse(importlib.metadata.version("huggingface_hub")).base_version
) >= version.parse(library_version)
else:
return version.parse(importlib.metadata.version("huggingface_hub")) >= version.parse(library_version)
def is_torchdistx_available(): def is_torchdistx_available():
return _torchdistx_available return _torchdistx_available

View File

@@ -12,14 +12,17 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os
import subprocess import subprocess
import tempfile import tempfile
import textwrap import textwrap
from transformers import is_torch_available from transformers import is_torch_available
from transformers.integrations.tensor_parallel import get_packed_weights, repack_weights
from transformers.testing_utils import ( from transformers.testing_utils import (
TestCasePlus, TestCasePlus,
get_torch_dist_unique_port, get_torch_dist_unique_port,
require_huggingface_hub_greater_or_equal,
require_torch_multi_gpu, require_torch_multi_gpu,
) )
@@ -28,19 +31,51 @@ if is_torch_available():
import torch import torch
class TestTensorParallelUtils(TestCasePlus):
def test_packed_unpacked_conversion(self):
WORLD_SIZE = 2
PACKED_BLOCK_SIZE = 800
SHARDING_DIM = 2
NUM_BLOCKS = 2
original_packed_weights = torch.randn(4, 512, 2 * PACKED_BLOCK_SIZE)
original_packed_weights.get_dtype = lambda: "F32" # get_packed_weights expects PySlice object
empty_param = torch.empty(4, 512, 2 * PACKED_BLOCK_SIZE)
class MockDeviceMesh:
def size(self):
return WORLD_SIZE
mock_mesh = (
MockDeviceMesh()
) # get_packed_weights only calls `.size()`, do this to avoid doing actual distributed run
packed_weights_0 = get_packed_weights(original_packed_weights, empty_param, mock_mesh, 0, SHARDING_DIM)
packed_weights_1 = get_packed_weights(original_packed_weights, empty_param, mock_mesh, 1, SHARDING_DIM)
# simulate all gather of sharded weights
packed_weights = torch.cat([packed_weights_0, packed_weights_1], dim=SHARDING_DIM)
unpacked_weights = repack_weights(packed_weights, SHARDING_DIM, WORLD_SIZE, NUM_BLOCKS)
assert torch.allclose(unpacked_weights, original_packed_weights)
# RUN_SLOW=1 pytest -sv tests/tensor_parallel/test_tensor_parallel.py # RUN_SLOW=1 pytest -sv tests/tensor_parallel/test_tensor_parallel.py
class TestTensorParallel(TestCasePlus): class TestTensorParallel(TestCasePlus):
nproc_per_node = 2 nproc_per_node = 2
def torchrun(self, script: str): def torchrun(self, script: str, is_torchrun: bool = True):
"""Run the `script` using `torchrun` command for multi-processing in a subprocess. Captures errors as necessary.""" """Run the `script` using `torchrun` command for multi-processing in a subprocess. Captures errors as necessary."""
with tempfile.NamedTemporaryFile(mode="w+", suffix=".py") as tmp: with tempfile.NamedTemporaryFile(mode="w+", suffix=".py") as tmp:
tmp.write(script) tmp.write(script)
tmp.flush() tmp.flush()
tmp.seek(0) tmp.seek(0)
cmd = ( if is_torchrun:
f"torchrun --nproc_per_node {self.nproc_per_node} --master_port {get_torch_dist_unique_port()} {tmp.name}" cmd = (
).split() f"torchrun --nproc_per_node {self.nproc_per_node} --master_port {get_torch_dist_unique_port()} {tmp.name}"
).split()
else:
cmd = ["python", tmp.name]
# Note that the subprocess will be waited for here, and raise an error if not successful # Note that the subprocess will be waited for here, and raise an error if not successful
try: try:
@@ -88,6 +123,48 @@ class TestTensorParallel(TestCasePlus):
) )
self.torchrun(script_to_run) self.torchrun(script_to_run)
@require_huggingface_hub_greater_or_equal("0.31.4")
def test_model_save(self):
from safetensors import safe_open
with tempfile.TemporaryDirectory() as tmp_dir:
for is_torchrun in [True, False]:
script_to_run = textwrap.dedent(
f"""
import torch
import os
from transformers import AutoModelForCausalLM
model_id = "JackFram/llama-68m"
kwargs = dict()
if os.environ.get("RANK", None) is not None:
kwargs["tp_plan"] = "auto"
result_dir = "{tmp_dir}/tp"
else:
result_dir = "{tmp_dir}/nontp"
model = AutoModelForCausalLM.from_pretrained(model_id, **kwargs)
model.save_pretrained(result_dir)
"""
)
self.torchrun(script_to_run, is_torchrun=is_torchrun)
non_tp_model_path = os.path.join(tmp_dir, "nontp")
tp_model_path = os.path.join(tmp_dir, "tp")
for filename in os.listdir(non_tp_model_path):
if not filename.endswith(".safetensors"):
continue
non_tp_model = safe_open(os.path.join(non_tp_model_path, filename), device="cpu", framework="pt")
tp_model = safe_open(os.path.join(tp_model_path, filename), device="cpu", framework="pt")
for non_tp_key in non_tp_model.keys():
non_tp_tensor = non_tp_model.get_tensor(non_tp_key)
tp_tensor = tp_model.get_tensor(non_tp_key)
assert torch.allclose(non_tp_tensor, tp_tensor), f"Tensor with key: {non_tp_key} does not match"
del non_tp_tensor, tp_tensor
@require_torch_multi_gpu @require_torch_multi_gpu
class TestTensorParallelCuda(TestTensorParallel): class TestTensorParallelCuda(TestTensorParallel):