Feat: save_pretrained for tensor parallel (and other parallelisms) models (#37919)
* tmp: initial save pretrained with dtensors * Feat: add correctness tests * Refactor: version checks * Temp: 1:1 checkpoint llama4 * refactor * Tests * Feat: works * Style * Feat: version checks + minor fixes * Style * Fix: version checks in tests * Feat: move more stuff into tensor_parallel.py
This commit is contained in:
@@ -61,6 +61,22 @@ def _blocks_to_block_sizes(total_size: int, blocks: Union[int, List[int]]) -> Li
|
||||
return [single_size] * blocks
|
||||
|
||||
|
||||
def _get_parameter_tp_plan(parameter_name: str, tp_plan: dict[str, str]) -> Optional[str]:
|
||||
"""
|
||||
Get the TP style for a parameter from the TP plan.
|
||||
|
||||
The TP plan is a dictionary that maps parameter names to TP styles.
|
||||
The parameter name can be a generic name with wildcards (e.g. "*.weight") or a specific name (e.g. "layer_1.weight").
|
||||
"""
|
||||
generic_param_name = re.sub(r"\d+", "*", parameter_name)
|
||||
if generic_param_name in tp_plan:
|
||||
return tp_plan[generic_param_name]
|
||||
elif "." in generic_param_name and generic_param_name.rsplit(".", 1)[0] in tp_plan:
|
||||
return tp_plan[generic_param_name.rsplit(".", 1)[0]]
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
str_to_torch_dtype = {
|
||||
"BOOL": torch.bool,
|
||||
"U8": torch.uint8,
|
||||
@@ -138,6 +154,71 @@ def get_packed_weights(param, empty_param, device_mesh, rank, dim):
|
||||
return tensor.to(str_to_torch_dtype[slice_dtype])
|
||||
|
||||
|
||||
def repack_weights(
|
||||
packed_parameter: torch.Tensor,
|
||||
sharded_dim: int, # The dimension index in the global tensor that was sharded
|
||||
world_size: int,
|
||||
num_blocks: int = 2,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Reorders a tensor that was reconstructed from sharded packed weights into its canonical packed format.
|
||||
|
||||
For example, if a weight was packed (e.g., gate_proj and up_proj) and then sharded,
|
||||
DTensor.full_tensor() might produce an interleaved layout like [G0, U0, G1, U1, ...]
|
||||
along the sharded dimension. This function reorders it to [G0, G1, ..., U0, U1, ...].
|
||||
This is an inverse operation to get_packed_weights.
|
||||
|
||||
Args:
|
||||
reconstructed_tensor: The tensor reconstructed from DTensor (e.g., via .full_tensor().contiguous()).
|
||||
sharded_dim: The dimension index in the reconstructed_tensor that was originally sharded.
|
||||
world_size: The tensor parallel world size.
|
||||
num_packed_projs: The number of projections that were packed together (e.g., 2 for gate_up_proj).
|
||||
|
||||
Returns:
|
||||
The reordered tensor in canonical packed format.
|
||||
"""
|
||||
|
||||
if num_blocks != 2:
|
||||
raise ValueError(
|
||||
"Num blocks different from 2 is not supported yet. This is most likely a bug in your implementation as we only pack gate and up projections together."
|
||||
)
|
||||
|
||||
actual_sharded_dim = sharded_dim if sharded_dim >= 0 else sharded_dim + packed_parameter.ndim
|
||||
total_size_on_sharded_dim = packed_parameter.shape[actual_sharded_dim]
|
||||
original_block_size_on_dim = total_size_on_sharded_dim // num_blocks
|
||||
shard_chunk_size = original_block_size_on_dim // world_size
|
||||
|
||||
prefix_shape = packed_parameter.shape[:actual_sharded_dim]
|
||||
suffix_shape = packed_parameter.shape[actual_sharded_dim + 1 :]
|
||||
|
||||
tensor_view = packed_parameter.view(
|
||||
*prefix_shape,
|
||||
world_size,
|
||||
num_blocks,
|
||||
shard_chunk_size,
|
||||
*suffix_shape,
|
||||
)
|
||||
|
||||
# Permute to bring num_packed_projs first, then world_size, then shard_chunk_size
|
||||
# This groups all chunks of G together, then all chunks of U together.
|
||||
# Target order of these middle dimensions: (num_packed_projs, world_size, shard_chunk_size)
|
||||
# Current order of view's middle dimensions: (world_size, num_packed_projs, shard_chunk_size)
|
||||
# Absolute indices of the dimensions to be permuted (world_size, num_packed_projs)
|
||||
axis_ws_abs = len(prefix_shape)
|
||||
axis_npp_abs = len(prefix_shape) + 1
|
||||
|
||||
permute_order = list(range(tensor_view.ndim))
|
||||
permute_order[axis_ws_abs], permute_order[axis_npp_abs] = permute_order[axis_npp_abs], permute_order[axis_ws_abs]
|
||||
|
||||
tensor_permuted = tensor_view.permute(*permute_order)
|
||||
|
||||
# Reshape back to the original tensor's ndim, with the sharded dimension now correctly ordered as [G_all, U_all].
|
||||
# The final shape should be the same as reconstructed_tensor.
|
||||
final_ordered_tensor = tensor_permuted.reshape_as(packed_parameter)
|
||||
|
||||
return final_ordered_tensor
|
||||
|
||||
|
||||
def get_tensor_shard(param, empty_param, device_mesh, rank, dim):
|
||||
if dim == 0:
|
||||
size_ = empty_param.shape[0]
|
||||
@@ -578,6 +659,49 @@ def translate_to_torch_parallel_style(style: str):
|
||||
raise ValueError(f"Unsupported parallel style value: {style}")
|
||||
|
||||
|
||||
def convert_local_tensor_to_dtensor(
|
||||
parameter: torch.Tensor, parameter_name: str, device_mesh, tp_plan: dict[str, str]
|
||||
) -> DTensor:
|
||||
"""
|
||||
Converts a local variant of weights to a DTensor with corresponding placements. Shouldn't be done ever except of before saving the model.
|
||||
"""
|
||||
_, param_type = parameter_name.rsplit(".", 1) if "." in parameter_name else parameter_name
|
||||
tp_style = _get_parameter_tp_plan(parameter_name, tp_plan)
|
||||
if not tp_style:
|
||||
return parameter
|
||||
|
||||
if tp_style not in ["local_packed_rowwise", "local_rowwise", "local_colwise"]:
|
||||
return parameter
|
||||
# TODO: this logic should be wrapped in a function, this is copied from corresponding tp classes.
|
||||
if tp_style == "local_packed_rowwise":
|
||||
placements = [Shard(-1)]
|
||||
elif tp_style == "local_rowwise":
|
||||
if param_type == "bias":
|
||||
placements = [Replicate()]
|
||||
else:
|
||||
placements = [Shard(-1)]
|
||||
elif tp_style == "local_colwise":
|
||||
if param_type == "bias":
|
||||
placements = [Shard(-1)]
|
||||
else:
|
||||
placements = [Shard(-2)]
|
||||
return DTensor.from_local(parameter, device_mesh, placements, run_check=False)
|
||||
|
||||
|
||||
def replace_state_dict_local_with_dtensor(
|
||||
state_dict: dict[str, torch.Tensor],
|
||||
tp_plan: dict[str, str],
|
||||
device_mesh,
|
||||
) -> dict[str, torch.Tensor]:
|
||||
"""
|
||||
Replaces all tensors that were sharded with `local_*` strategy with DTensor to make determining their proper size possible.
|
||||
"""
|
||||
for key, value in state_dict.items():
|
||||
if isinstance(value, torch.Tensor) and not isinstance(value, DTensor):
|
||||
state_dict[key] = convert_local_tensor_to_dtensor(value, key, device_mesh, tp_plan)
|
||||
return state_dict
|
||||
|
||||
|
||||
def add_tensor_parallel_hooks_to_module(model, module, tp_plan, layer_name, current_module_plan, device_mesh):
|
||||
"""
|
||||
Add hooks to the module holding the layer. Meaning:
|
||||
@@ -632,13 +756,9 @@ def shard_and_distribute_module(
|
||||
param_name, param_type = parameter_name.rsplit(".", 1) if "." in parameter_name else parameter_name
|
||||
tp_plan = model._tp_plan
|
||||
module_to_tp = model.get_submodule(param_name)
|
||||
current_module_plan = None
|
||||
rank = int(rank)
|
||||
generic_param_name = re.sub(r"\d+", "*", parameter_name)
|
||||
if generic_param_name in tp_plan:
|
||||
current_module_plan = tp_plan[generic_param_name]
|
||||
elif "." in generic_param_name and generic_param_name.rsplit(".", 1)[0] in tp_plan:
|
||||
current_module_plan = tp_plan[generic_param_name.rsplit(".", 1)[0]]
|
||||
|
||||
current_module_plan = _get_parameter_tp_plan(parameter_name, tp_plan)
|
||||
|
||||
# Add hooks to the module if not done yet
|
||||
# add_tensor_parallel_hooks_to_module(model, module_to_tp, tp_plan, param_name, current_module_plan, device_mesh)
|
||||
|
||||
@@ -63,6 +63,9 @@ from .integrations.flex_attention import flex_attention_forward
|
||||
from .integrations.sdpa_attention import sdpa_attention_forward
|
||||
from .integrations.tensor_parallel import (
|
||||
SUPPORTED_TP_STYLES,
|
||||
_get_parameter_tp_plan,
|
||||
repack_weights,
|
||||
replace_state_dict_local_with_dtensor,
|
||||
shard_and_distribute_module,
|
||||
verify_tp_plan,
|
||||
)
|
||||
@@ -123,6 +126,7 @@ from .utils import (
|
||||
from .utils.hub import create_and_tag_model_card, get_checkpoint_shard_files
|
||||
from .utils.import_utils import (
|
||||
ENV_VARS_TRUE_VALUES,
|
||||
is_huggingface_hub_greater_or_equal,
|
||||
is_sagemaker_mp_enabled,
|
||||
is_torch_fx_proxy,
|
||||
is_torchdynamo_compiling,
|
||||
@@ -168,6 +172,9 @@ _is_quantized = False
|
||||
_is_ds_init_called = False
|
||||
_torch_distributed_available = torch.distributed.is_available()
|
||||
|
||||
if _torch_distributed_available and is_torch_greater_or_equal("2.5"):
|
||||
from torch.distributed.tensor import DTensor
|
||||
|
||||
|
||||
def is_fsdp_enabled():
|
||||
return (
|
||||
@@ -3413,6 +3420,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
if safe_serialization and not is_safetensors_available():
|
||||
raise ImportError("`safe_serialization` requires the `safetensors library: `pip install safetensors`.")
|
||||
|
||||
# we need to check against tp_size, not tp_plan, as tp_plan is substituted to the class one
|
||||
if self._tp_size is not None and not is_huggingface_hub_greater_or_equal("0.31.4"):
|
||||
raise ImportError(
|
||||
"Saving a model with tensor parallelism requires `huggingface_hub` version 0.31.4 or higher."
|
||||
)
|
||||
|
||||
if os.path.isfile(save_directory):
|
||||
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
|
||||
return
|
||||
@@ -3540,6 +3553,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
# Rename state_dict keys before saving to file. Do nothing unless overridden in a particular model.
|
||||
# (initially introduced with TimmWrapperModel to remove prefix and make checkpoints compatible with timm)
|
||||
state_dict = self._fix_state_dict_keys_on_save(state_dict)
|
||||
# If model was sharded, we cannot properly determine sizes of tensors that `local_*` strategy was used,
|
||||
# therefore we replace them with DTensors that are equivalently sharded
|
||||
if self._tp_size is not None:
|
||||
state_dict = replace_state_dict_local_with_dtensor(state_dict, self._tp_plan, self._device_mesh)
|
||||
|
||||
if safe_serialization:
|
||||
# Safetensors does not allow tensor aliasing.
|
||||
@@ -3548,7 +3565,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
for name, tensor in state_dict.items():
|
||||
# Sometimes in the state_dict we have non-tensor objects.
|
||||
# e.g. in bitsandbytes we have some `str` objects in the state_dict
|
||||
if isinstance(tensor, torch.Tensor):
|
||||
if isinstance(tensor, torch.Tensor) or isinstance(tensor, DTensor):
|
||||
ptrs[id_tensor_storage(tensor)].append(name)
|
||||
else:
|
||||
# In the non-tensor case, fall back to the pointer of the object itself
|
||||
@@ -3658,6 +3675,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
for shard_file, tensors in filename_to_tensors:
|
||||
shard = {}
|
||||
for tensor in tensors:
|
||||
if isinstance(state_dict[tensor], DTensor):
|
||||
full_tensor = state_dict[tensor].full_tensor()
|
||||
# to get the correctly ordered tensor we need to repack if packed
|
||||
if _get_parameter_tp_plan(tensor, self._tp_plan) in ("local_packed_rowwise",):
|
||||
full_tensor = repack_weights(full_tensor, -1, self._tp_size, 2)
|
||||
shard[tensor] = full_tensor.contiguous() # only do contiguous after it's permuted correctly
|
||||
else:
|
||||
shard[tensor] = state_dict[tensor].contiguous()
|
||||
# delete reference, see https://github.com/huggingface/transformers/pull/34890
|
||||
del state_dict[tensor]
|
||||
@@ -4606,6 +4630,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
|
||||
# record tp degree the model sharded to
|
||||
model._tp_size = tp_size
|
||||
model._device_mesh = device_mesh
|
||||
|
||||
# make sure token embedding weights are still tied if needed
|
||||
model.tie_weights()
|
||||
|
||||
@@ -296,6 +296,13 @@ def id_tensor_storage(tensor: torch.Tensor) -> tuple[torch.device, int, int]:
|
||||
guaranteed to be unique and constant for this tensor's storage during its lifetime. Two tensor storages with
|
||||
non-overlapping lifetimes may have the same id.
|
||||
"""
|
||||
if is_torch_greater_or_equal_than_2_0:
|
||||
from torch.distributed.tensor import DTensor
|
||||
|
||||
if isinstance(tensor, DTensor):
|
||||
local_tensor = tensor.to_local()
|
||||
return tensor.device, local_tensor.storage().data_ptr(), tensor.nbytes
|
||||
|
||||
if tensor.device.type == "xla" and is_torch_xla_available():
|
||||
# NOTE: xla tensors dont have storage
|
||||
# use some other unique id to distinguish.
|
||||
|
||||
@@ -97,6 +97,7 @@ from .utils import (
|
||||
is_grokadamw_available,
|
||||
is_hadamard_available,
|
||||
is_hqq_available,
|
||||
is_huggingface_hub_greater_or_equal,
|
||||
is_ipex_available,
|
||||
is_jieba_available,
|
||||
is_jinja_available,
|
||||
@@ -542,6 +543,21 @@ def require_torch_greater_or_equal(version: str):
|
||||
return decorator
|
||||
|
||||
|
||||
def require_huggingface_hub_greater_or_equal(version: str):
|
||||
"""
|
||||
Decorator marking a test that requires huggingface_hub version >= `version`.
|
||||
|
||||
These tests are skipped when huggingface_hub version is less than `version`.
|
||||
"""
|
||||
|
||||
def decorator(test_case):
|
||||
return unittest.skipUnless(
|
||||
is_huggingface_hub_greater_or_equal(version), f"test requires huggingface_hub version >= {version}"
|
||||
)(test_case)
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def require_flash_attn(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires Flash Attention.
|
||||
|
||||
@@ -167,6 +167,7 @@ from .import_utils import (
|
||||
is_habana_gaudi1,
|
||||
is_hadamard_available,
|
||||
is_hqq_available,
|
||||
is_huggingface_hub_greater_or_equal,
|
||||
is_in_notebook,
|
||||
is_ipex_available,
|
||||
is_jieba_available,
|
||||
|
||||
@@ -1077,6 +1077,19 @@ def is_torch_greater_or_equal(library_version: str, accept_dev: bool = False):
|
||||
return version.parse(importlib.metadata.version("torch")) >= version.parse(library_version)
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def is_huggingface_hub_greater_or_equal(library_version: str, accept_dev: bool = False):
|
||||
if not _is_package_available("huggingface_hub"):
|
||||
return False
|
||||
|
||||
if accept_dev:
|
||||
return version.parse(
|
||||
version.parse(importlib.metadata.version("huggingface_hub")).base_version
|
||||
) >= version.parse(library_version)
|
||||
else:
|
||||
return version.parse(importlib.metadata.version("huggingface_hub")) >= version.parse(library_version)
|
||||
|
||||
|
||||
def is_torchdistx_available():
|
||||
return _torchdistx_available
|
||||
|
||||
|
||||
@@ -12,14 +12,17 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
import tempfile
|
||||
import textwrap
|
||||
|
||||
from transformers import is_torch_available
|
||||
from transformers.integrations.tensor_parallel import get_packed_weights, repack_weights
|
||||
from transformers.testing_utils import (
|
||||
TestCasePlus,
|
||||
get_torch_dist_unique_port,
|
||||
require_huggingface_hub_greater_or_equal,
|
||||
require_torch_multi_gpu,
|
||||
)
|
||||
|
||||
@@ -28,19 +31,51 @@ if is_torch_available():
|
||||
import torch
|
||||
|
||||
|
||||
class TestTensorParallelUtils(TestCasePlus):
|
||||
def test_packed_unpacked_conversion(self):
|
||||
WORLD_SIZE = 2
|
||||
PACKED_BLOCK_SIZE = 800
|
||||
SHARDING_DIM = 2
|
||||
NUM_BLOCKS = 2
|
||||
|
||||
original_packed_weights = torch.randn(4, 512, 2 * PACKED_BLOCK_SIZE)
|
||||
original_packed_weights.get_dtype = lambda: "F32" # get_packed_weights expects PySlice object
|
||||
empty_param = torch.empty(4, 512, 2 * PACKED_BLOCK_SIZE)
|
||||
|
||||
class MockDeviceMesh:
|
||||
def size(self):
|
||||
return WORLD_SIZE
|
||||
|
||||
mock_mesh = (
|
||||
MockDeviceMesh()
|
||||
) # get_packed_weights only calls `.size()`, do this to avoid doing actual distributed run
|
||||
|
||||
packed_weights_0 = get_packed_weights(original_packed_weights, empty_param, mock_mesh, 0, SHARDING_DIM)
|
||||
packed_weights_1 = get_packed_weights(original_packed_weights, empty_param, mock_mesh, 1, SHARDING_DIM)
|
||||
|
||||
# simulate all gather of sharded weights
|
||||
packed_weights = torch.cat([packed_weights_0, packed_weights_1], dim=SHARDING_DIM)
|
||||
unpacked_weights = repack_weights(packed_weights, SHARDING_DIM, WORLD_SIZE, NUM_BLOCKS)
|
||||
|
||||
assert torch.allclose(unpacked_weights, original_packed_weights)
|
||||
|
||||
|
||||
# RUN_SLOW=1 pytest -sv tests/tensor_parallel/test_tensor_parallel.py
|
||||
class TestTensorParallel(TestCasePlus):
|
||||
nproc_per_node = 2
|
||||
|
||||
def torchrun(self, script: str):
|
||||
def torchrun(self, script: str, is_torchrun: bool = True):
|
||||
"""Run the `script` using `torchrun` command for multi-processing in a subprocess. Captures errors as necessary."""
|
||||
with tempfile.NamedTemporaryFile(mode="w+", suffix=".py") as tmp:
|
||||
tmp.write(script)
|
||||
tmp.flush()
|
||||
tmp.seek(0)
|
||||
if is_torchrun:
|
||||
cmd = (
|
||||
f"torchrun --nproc_per_node {self.nproc_per_node} --master_port {get_torch_dist_unique_port()} {tmp.name}"
|
||||
).split()
|
||||
else:
|
||||
cmd = ["python", tmp.name]
|
||||
|
||||
# Note that the subprocess will be waited for here, and raise an error if not successful
|
||||
try:
|
||||
@@ -88,6 +123,48 @@ class TestTensorParallel(TestCasePlus):
|
||||
)
|
||||
self.torchrun(script_to_run)
|
||||
|
||||
@require_huggingface_hub_greater_or_equal("0.31.4")
|
||||
def test_model_save(self):
|
||||
from safetensors import safe_open
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
for is_torchrun in [True, False]:
|
||||
script_to_run = textwrap.dedent(
|
||||
f"""
|
||||
import torch
|
||||
import os
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
model_id = "JackFram/llama-68m"
|
||||
kwargs = dict()
|
||||
|
||||
if os.environ.get("RANK", None) is not None:
|
||||
kwargs["tp_plan"] = "auto"
|
||||
result_dir = "{tmp_dir}/tp"
|
||||
else:
|
||||
result_dir = "{tmp_dir}/nontp"
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id, **kwargs)
|
||||
model.save_pretrained(result_dir)
|
||||
"""
|
||||
)
|
||||
self.torchrun(script_to_run, is_torchrun=is_torchrun)
|
||||
|
||||
non_tp_model_path = os.path.join(tmp_dir, "nontp")
|
||||
tp_model_path = os.path.join(tmp_dir, "tp")
|
||||
|
||||
for filename in os.listdir(non_tp_model_path):
|
||||
if not filename.endswith(".safetensors"):
|
||||
continue
|
||||
|
||||
non_tp_model = safe_open(os.path.join(non_tp_model_path, filename), device="cpu", framework="pt")
|
||||
tp_model = safe_open(os.path.join(tp_model_path, filename), device="cpu", framework="pt")
|
||||
for non_tp_key in non_tp_model.keys():
|
||||
non_tp_tensor = non_tp_model.get_tensor(non_tp_key)
|
||||
tp_tensor = tp_model.get_tensor(non_tp_key)
|
||||
assert torch.allclose(non_tp_tensor, tp_tensor), f"Tensor with key: {non_tp_key} does not match"
|
||||
del non_tp_tensor, tp_tensor
|
||||
|
||||
|
||||
@require_torch_multi_gpu
|
||||
class TestTensorParallelCuda(TestTensorParallel):
|
||||
|
||||
Reference in New Issue
Block a user