Feat: save_pretrained for tensor parallel (and other parallelisms) models (#37919)
* tmp: initial save pretrained with dtensors * Feat: add correctness tests * Refactor: version checks * Temp: 1:1 checkpoint llama4 * refactor * Tests * Feat: works * Style * Feat: version checks + minor fixes * Style * Fix: version checks in tests * Feat: move more stuff into tensor_parallel.py
This commit is contained in:
@@ -61,6 +61,22 @@ def _blocks_to_block_sizes(total_size: int, blocks: Union[int, List[int]]) -> Li
|
|||||||
return [single_size] * blocks
|
return [single_size] * blocks
|
||||||
|
|
||||||
|
|
||||||
|
def _get_parameter_tp_plan(parameter_name: str, tp_plan: dict[str, str]) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
Get the TP style for a parameter from the TP plan.
|
||||||
|
|
||||||
|
The TP plan is a dictionary that maps parameter names to TP styles.
|
||||||
|
The parameter name can be a generic name with wildcards (e.g. "*.weight") or a specific name (e.g. "layer_1.weight").
|
||||||
|
"""
|
||||||
|
generic_param_name = re.sub(r"\d+", "*", parameter_name)
|
||||||
|
if generic_param_name in tp_plan:
|
||||||
|
return tp_plan[generic_param_name]
|
||||||
|
elif "." in generic_param_name and generic_param_name.rsplit(".", 1)[0] in tp_plan:
|
||||||
|
return tp_plan[generic_param_name.rsplit(".", 1)[0]]
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
str_to_torch_dtype = {
|
str_to_torch_dtype = {
|
||||||
"BOOL": torch.bool,
|
"BOOL": torch.bool,
|
||||||
"U8": torch.uint8,
|
"U8": torch.uint8,
|
||||||
@@ -138,6 +154,71 @@ def get_packed_weights(param, empty_param, device_mesh, rank, dim):
|
|||||||
return tensor.to(str_to_torch_dtype[slice_dtype])
|
return tensor.to(str_to_torch_dtype[slice_dtype])
|
||||||
|
|
||||||
|
|
||||||
|
def repack_weights(
    packed_parameter: torch.Tensor,
    sharded_dim: int,  # The dimension index in the global tensor that was sharded
    world_size: int,
    num_blocks: int = 2,
) -> torch.Tensor:
    """
    Reorders a tensor that was reconstructed from sharded packed weights into its canonical packed format.

    For example, if a weight was packed (e.g., gate_proj and up_proj) and then sharded,
    DTensor.full_tensor() might produce an interleaved layout like [G0, U0, G1, U1, ...]
    along the sharded dimension. This function reorders it to [G0, G1, ..., U0, U1, ...].
    This is an inverse operation to get_packed_weights.

    Args:
        packed_parameter: The tensor reconstructed from the shards (e.g., via DTensor.full_tensor()).
        sharded_dim: The dimension index in `packed_parameter` that was originally sharded. May be negative.
        world_size: The tensor parallel world size.
        num_blocks: The number of projections that were packed together (e.g., 2 for gate_up_proj).
            Only 2 is supported.

    Returns:
        The reordered tensor in canonical packed format, with the same shape as `packed_parameter`.

    Raises:
        ValueError: If `num_blocks` is not 2, if `world_size` is not positive, or if the size of the sharded
            dimension is not divisible by `num_blocks * world_size`.
    """

    if num_blocks != 2:
        raise ValueError(
            "Num blocks different from 2 is not supported yet. This is most likely a bug in your implementation as we only pack gate and up projections together."
        )

    actual_sharded_dim = sharded_dim if sharded_dim >= 0 else sharded_dim + packed_parameter.ndim
    total_size_on_sharded_dim = packed_parameter.shape[actual_sharded_dim]

    # Fail early with an explicit message instead of a cryptic shape error from `view` (or a division by zero).
    if world_size <= 0:
        raise ValueError(f"`world_size` must be a positive integer, got {world_size}.")
    if total_size_on_sharded_dim % (num_blocks * world_size) != 0:
        raise ValueError(
            f"Size of the sharded dimension ({total_size_on_sharded_dim}) must be divisible by "
            f"num_blocks * world_size ({num_blocks} * {world_size})."
        )

    original_block_size_on_dim = total_size_on_sharded_dim // num_blocks
    shard_chunk_size = original_block_size_on_dim // world_size

    prefix_shape = packed_parameter.shape[:actual_sharded_dim]
    suffix_shape = packed_parameter.shape[actual_sharded_dim + 1 :]

    # Split the sharded dimension into (world_size, num_blocks, shard_chunk_size): each rank contributed
    # one chunk of every packed block, laid out rank after rank.
    tensor_view = packed_parameter.view(
        *prefix_shape,
        world_size,
        num_blocks,
        shard_chunk_size,
        *suffix_shape,
    )

    # Swap the world_size and num_blocks axes so that all chunks of G come first, then all chunks of U.
    # Current order of the middle dimensions: (world_size, num_blocks, shard_chunk_size)
    # Target order of the middle dimensions:  (num_blocks, world_size, shard_chunk_size)
    axis_ws_abs = len(prefix_shape)
    axis_blocks_abs = len(prefix_shape) + 1

    permute_order = list(range(tensor_view.ndim))
    permute_order[axis_ws_abs], permute_order[axis_blocks_abs] = (
        permute_order[axis_blocks_abs],
        permute_order[axis_ws_abs],
    )

    tensor_permuted = tensor_view.permute(*permute_order)

    # Collapse back to the input's shape; the sharded dimension is now ordered as [G_all, U_all].
    final_ordered_tensor = tensor_permuted.reshape_as(packed_parameter)

    return final_ordered_tensor
|
||||||
|
|
||||||
|
|
||||||
def get_tensor_shard(param, empty_param, device_mesh, rank, dim):
|
def get_tensor_shard(param, empty_param, device_mesh, rank, dim):
|
||||||
if dim == 0:
|
if dim == 0:
|
||||||
size_ = empty_param.shape[0]
|
size_ = empty_param.shape[0]
|
||||||
@@ -578,6 +659,49 @@ def translate_to_torch_parallel_style(style: str):
|
|||||||
raise ValueError(f"Unsupported parallel style value: {style}")
|
raise ValueError(f"Unsupported parallel style value: {style}")
|
||||||
|
|
||||||
|
|
||||||
|
def convert_local_tensor_to_dtensor(
    parameter: torch.Tensor, parameter_name: str, device_mesh, tp_plan: dict[str, str]
) -> DTensor:
    """
    Converts a local variant of weights to a DTensor with corresponding placements. This should only ever be done
    right before saving the model.

    Args:
        parameter: The local tensor held by this rank.
        parameter_name: Name of the parameter, used to look up its TP style in `tp_plan`.
        device_mesh: The device mesh passed to `DTensor.from_local`.
        tp_plan: Mapping from (wildcarded) parameter/module names to TP styles.

    Returns:
        A DTensor wrapping `parameter` when its TP style is one of `local_packed_rowwise`, `local_rowwise` or
        `local_colwise`; otherwise `parameter` itself, unchanged.
    """
    # Last dotted component of the name ("weight", "bias", ...). Indexing the `rsplit` result also handles names
    # without any dot, where unpacking the bare string into two variables would fail.
    param_type = parameter_name.rsplit(".", 1)[-1]

    tp_style = _get_parameter_tp_plan(parameter_name, tp_plan)
    # Covers both "no plan for this parameter" (None) and any non-`local_*` style.
    if tp_style not in ("local_packed_rowwise", "local_rowwise", "local_colwise"):
        return parameter

    # TODO: this logic should be wrapped in a function, this is copied from corresponding tp classes.
    if tp_style == "local_packed_rowwise":
        placements = [Shard(-1)]
    elif tp_style == "local_rowwise":
        placements = [Replicate()] if param_type == "bias" else [Shard(-1)]
    else:  # "local_colwise"
        placements = [Shard(-1)] if param_type == "bias" else [Shard(-2)]

    return DTensor.from_local(parameter, device_mesh, placements, run_check=False)
|
||||||
|
|
||||||
|
|
||||||
|
def replace_state_dict_local_with_dtensor(
    state_dict: dict[str, torch.Tensor],
    tp_plan: dict[str, str],
    device_mesh,
) -> dict[str, torch.Tensor]:
    """
    Swap, in place, every plain tensor of `state_dict` for the result of `convert_local_tensor_to_dtensor`, so that
    tensors sharded with a `local_*` strategy become DTensors and their proper size can be determined.

    Entries that are not tensors, or that already are DTensors, are left untouched. The same (mutated) dictionary is
    returned.
    """
    # Only values of existing keys are reassigned, so the dictionary can be updated while being iterated.
    for name in state_dict:
        entry = state_dict[name]
        if not isinstance(entry, torch.Tensor) or isinstance(entry, DTensor):
            continue
        state_dict[name] = convert_local_tensor_to_dtensor(entry, name, device_mesh, tp_plan)
    return state_dict
|
||||||
|
|
||||||
|
|
||||||
def add_tensor_parallel_hooks_to_module(model, module, tp_plan, layer_name, current_module_plan, device_mesh):
|
def add_tensor_parallel_hooks_to_module(model, module, tp_plan, layer_name, current_module_plan, device_mesh):
|
||||||
"""
|
"""
|
||||||
Add hooks to the module holding the layer. Meaning:
|
Add hooks to the module holding the layer. Meaning:
|
||||||
@@ -632,13 +756,9 @@ def shard_and_distribute_module(
|
|||||||
param_name, param_type = parameter_name.rsplit(".", 1) if "." in parameter_name else parameter_name
|
param_name, param_type = parameter_name.rsplit(".", 1) if "." in parameter_name else parameter_name
|
||||||
tp_plan = model._tp_plan
|
tp_plan = model._tp_plan
|
||||||
module_to_tp = model.get_submodule(param_name)
|
module_to_tp = model.get_submodule(param_name)
|
||||||
current_module_plan = None
|
|
||||||
rank = int(rank)
|
rank = int(rank)
|
||||||
generic_param_name = re.sub(r"\d+", "*", parameter_name)
|
|
||||||
if generic_param_name in tp_plan:
|
current_module_plan = _get_parameter_tp_plan(parameter_name, tp_plan)
|
||||||
current_module_plan = tp_plan[generic_param_name]
|
|
||||||
elif "." in generic_param_name and generic_param_name.rsplit(".", 1)[0] in tp_plan:
|
|
||||||
current_module_plan = tp_plan[generic_param_name.rsplit(".", 1)[0]]
|
|
||||||
|
|
||||||
# Add hooks to the module if not done yet
|
# Add hooks to the module if not done yet
|
||||||
# add_tensor_parallel_hooks_to_module(model, module_to_tp, tp_plan, param_name, current_module_plan, device_mesh)
|
# add_tensor_parallel_hooks_to_module(model, module_to_tp, tp_plan, param_name, current_module_plan, device_mesh)
|
||||||
|
|||||||
@@ -63,6 +63,9 @@ from .integrations.flex_attention import flex_attention_forward
|
|||||||
from .integrations.sdpa_attention import sdpa_attention_forward
|
from .integrations.sdpa_attention import sdpa_attention_forward
|
||||||
from .integrations.tensor_parallel import (
|
from .integrations.tensor_parallel import (
|
||||||
SUPPORTED_TP_STYLES,
|
SUPPORTED_TP_STYLES,
|
||||||
|
_get_parameter_tp_plan,
|
||||||
|
repack_weights,
|
||||||
|
replace_state_dict_local_with_dtensor,
|
||||||
shard_and_distribute_module,
|
shard_and_distribute_module,
|
||||||
verify_tp_plan,
|
verify_tp_plan,
|
||||||
)
|
)
|
||||||
@@ -123,6 +126,7 @@ from .utils import (
|
|||||||
from .utils.hub import create_and_tag_model_card, get_checkpoint_shard_files
|
from .utils.hub import create_and_tag_model_card, get_checkpoint_shard_files
|
||||||
from .utils.import_utils import (
|
from .utils.import_utils import (
|
||||||
ENV_VARS_TRUE_VALUES,
|
ENV_VARS_TRUE_VALUES,
|
||||||
|
is_huggingface_hub_greater_or_equal,
|
||||||
is_sagemaker_mp_enabled,
|
is_sagemaker_mp_enabled,
|
||||||
is_torch_fx_proxy,
|
is_torch_fx_proxy,
|
||||||
is_torchdynamo_compiling,
|
is_torchdynamo_compiling,
|
||||||
@@ -168,6 +172,9 @@ _is_quantized = False
|
|||||||
_is_ds_init_called = False
|
_is_ds_init_called = False
|
||||||
_torch_distributed_available = torch.distributed.is_available()
|
_torch_distributed_available = torch.distributed.is_available()
|
||||||
|
|
||||||
|
if _torch_distributed_available and is_torch_greater_or_equal("2.5"):
|
||||||
|
from torch.distributed.tensor import DTensor
|
||||||
|
|
||||||
|
|
||||||
def is_fsdp_enabled():
|
def is_fsdp_enabled():
|
||||||
return (
|
return (
|
||||||
@@ -3413,6 +3420,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
|||||||
if safe_serialization and not is_safetensors_available():
|
if safe_serialization and not is_safetensors_available():
|
||||||
raise ImportError("`safe_serialization` requires the `safetensors library: `pip install safetensors`.")
|
raise ImportError("`safe_serialization` requires the `safetensors library: `pip install safetensors`.")
|
||||||
|
|
||||||
|
# we need to check against tp_size, not tp_plan, as tp_plan is substituted to the class one
|
||||||
|
if self._tp_size is not None and not is_huggingface_hub_greater_or_equal("0.31.4"):
|
||||||
|
raise ImportError(
|
||||||
|
"Saving a model with tensor parallelism requires `huggingface_hub` version 0.31.4 or higher."
|
||||||
|
)
|
||||||
|
|
||||||
if os.path.isfile(save_directory):
|
if os.path.isfile(save_directory):
|
||||||
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
|
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
|
||||||
return
|
return
|
||||||
@@ -3540,6 +3553,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
|||||||
# Rename state_dict keys before saving to file. Do nothing unless overridden in a particular model.
|
# Rename state_dict keys before saving to file. Do nothing unless overridden in a particular model.
|
||||||
# (initially introduced with TimmWrapperModel to remove prefix and make checkpoints compatible with timm)
|
# (initially introduced with TimmWrapperModel to remove prefix and make checkpoints compatible with timm)
|
||||||
state_dict = self._fix_state_dict_keys_on_save(state_dict)
|
state_dict = self._fix_state_dict_keys_on_save(state_dict)
|
||||||
|
# If model was sharded, we cannot properly determine sizes of tensors that `local_*` strategy was used,
|
||||||
|
# therefore we replace them with DTensors that are equivalently sharded
|
||||||
|
if self._tp_size is not None:
|
||||||
|
state_dict = replace_state_dict_local_with_dtensor(state_dict, self._tp_plan, self._device_mesh)
|
||||||
|
|
||||||
if safe_serialization:
|
if safe_serialization:
|
||||||
# Safetensors does not allow tensor aliasing.
|
# Safetensors does not allow tensor aliasing.
|
||||||
@@ -3548,7 +3565,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
|||||||
for name, tensor in state_dict.items():
|
for name, tensor in state_dict.items():
|
||||||
# Sometimes in the state_dict we have non-tensor objects.
|
# Sometimes in the state_dict we have non-tensor objects.
|
||||||
# e.g. in bitsandbytes we have some `str` objects in the state_dict
|
# e.g. in bitsandbytes we have some `str` objects in the state_dict
|
||||||
if isinstance(tensor, torch.Tensor):
|
if isinstance(tensor, torch.Tensor) or isinstance(tensor, DTensor):
|
||||||
ptrs[id_tensor_storage(tensor)].append(name)
|
ptrs[id_tensor_storage(tensor)].append(name)
|
||||||
else:
|
else:
|
||||||
# In the non-tensor case, fall back to the pointer of the object itself
|
# In the non-tensor case, fall back to the pointer of the object itself
|
||||||
@@ -3658,7 +3675,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
|||||||
for shard_file, tensors in filename_to_tensors:
|
for shard_file, tensors in filename_to_tensors:
|
||||||
shard = {}
|
shard = {}
|
||||||
for tensor in tensors:
|
for tensor in tensors:
|
||||||
shard[tensor] = state_dict[tensor].contiguous()
|
if isinstance(state_dict[tensor], DTensor):
|
||||||
|
full_tensor = state_dict[tensor].full_tensor()
|
||||||
|
# to get the correctly ordered tensor we need to repack if packed
|
||||||
|
if _get_parameter_tp_plan(tensor, self._tp_plan) in ("local_packed_rowwise",):
|
||||||
|
full_tensor = repack_weights(full_tensor, -1, self._tp_size, 2)
|
||||||
|
shard[tensor] = full_tensor.contiguous() # only do contiguous after it's permuted correctly
|
||||||
|
else:
|
||||||
|
shard[tensor] = state_dict[tensor].contiguous()
|
||||||
# delete reference, see https://github.com/huggingface/transformers/pull/34890
|
# delete reference, see https://github.com/huggingface/transformers/pull/34890
|
||||||
del state_dict[tensor]
|
del state_dict[tensor]
|
||||||
|
|
||||||
@@ -4606,6 +4630,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
|||||||
|
|
||||||
# record tp degree the model sharded to
|
# record tp degree the model sharded to
|
||||||
model._tp_size = tp_size
|
model._tp_size = tp_size
|
||||||
|
model._device_mesh = device_mesh
|
||||||
|
|
||||||
# make sure token embedding weights are still tied if needed
|
# make sure token embedding weights are still tied if needed
|
||||||
model.tie_weights()
|
model.tie_weights()
|
||||||
|
|||||||
@@ -296,6 +296,13 @@ def id_tensor_storage(tensor: torch.Tensor) -> tuple[torch.device, int, int]:
|
|||||||
guaranteed to be unique and constant for this tensor's storage during its lifetime. Two tensor storages with
|
guaranteed to be unique and constant for this tensor's storage during its lifetime. Two tensor storages with
|
||||||
non-overlapping lifetimes may have the same id.
|
non-overlapping lifetimes may have the same id.
|
||||||
"""
|
"""
|
||||||
|
if is_torch_greater_or_equal_than_2_0:
|
||||||
|
from torch.distributed.tensor import DTensor
|
||||||
|
|
||||||
|
if isinstance(tensor, DTensor):
|
||||||
|
local_tensor = tensor.to_local()
|
||||||
|
return tensor.device, local_tensor.storage().data_ptr(), tensor.nbytes
|
||||||
|
|
||||||
if tensor.device.type == "xla" and is_torch_xla_available():
|
if tensor.device.type == "xla" and is_torch_xla_available():
|
||||||
# NOTE: xla tensors dont have storage
|
# NOTE: xla tensors dont have storage
|
||||||
# use some other unique id to distinguish.
|
# use some other unique id to distinguish.
|
||||||
|
|||||||
@@ -97,6 +97,7 @@ from .utils import (
|
|||||||
is_grokadamw_available,
|
is_grokadamw_available,
|
||||||
is_hadamard_available,
|
is_hadamard_available,
|
||||||
is_hqq_available,
|
is_hqq_available,
|
||||||
|
is_huggingface_hub_greater_or_equal,
|
||||||
is_ipex_available,
|
is_ipex_available,
|
||||||
is_jieba_available,
|
is_jieba_available,
|
||||||
is_jinja_available,
|
is_jinja_available,
|
||||||
@@ -542,6 +543,21 @@ def require_torch_greater_or_equal(version: str):
|
|||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
def require_huggingface_hub_greater_or_equal(version: str):
    """
    Decorator factory marking a test that needs huggingface_hub version >= `version`.

    The decorated test is skipped when `is_huggingface_hub_greater_or_equal(version)` is false.
    """

    def decorator(test_case):
        skip_reason = f"test requires huggingface_hub version >= {version}"
        version_ok = is_huggingface_hub_greater_or_equal(version)
        return unittest.skipUnless(version_ok, skip_reason)(test_case)

    return decorator
|
||||||
|
|
||||||
|
|
||||||
def require_flash_attn(test_case):
|
def require_flash_attn(test_case):
|
||||||
"""
|
"""
|
||||||
Decorator marking a test that requires Flash Attention.
|
Decorator marking a test that requires Flash Attention.
|
||||||
|
|||||||
@@ -167,6 +167,7 @@ from .import_utils import (
|
|||||||
is_habana_gaudi1,
|
is_habana_gaudi1,
|
||||||
is_hadamard_available,
|
is_hadamard_available,
|
||||||
is_hqq_available,
|
is_hqq_available,
|
||||||
|
is_huggingface_hub_greater_or_equal,
|
||||||
is_in_notebook,
|
is_in_notebook,
|
||||||
is_ipex_available,
|
is_ipex_available,
|
||||||
is_jieba_available,
|
is_jieba_available,
|
||||||
|
|||||||
@@ -1077,6 +1077,19 @@ def is_torch_greater_or_equal(library_version: str, accept_dev: bool = False):
|
|||||||
return version.parse(importlib.metadata.version("torch")) >= version.parse(library_version)
|
return version.parse(importlib.metadata.version("torch")) >= version.parse(library_version)
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache()
def is_huggingface_hub_greater_or_equal(library_version: str, accept_dev: bool = False):
    """
    Tell whether the installed `huggingface_hub` is at least `library_version`.

    Returns `False` when `huggingface_hub` is not available. With `accept_dev=True`, only the base version of the
    installed package is compared against `library_version`.
    """
    if not _is_package_available("huggingface_hub"):
        return False

    installed_version = version.parse(importlib.metadata.version("huggingface_hub"))
    if accept_dev:
        # Compare on the base version only (the `base_version` attribute of the parsed version).
        installed_version = version.parse(installed_version.base_version)
    return installed_version >= version.parse(library_version)
|
||||||
|
|
||||||
|
|
||||||
def is_torchdistx_available():
|
def is_torchdistx_available():
|
||||||
return _torchdistx_available
|
return _torchdistx_available
|
||||||
|
|
||||||
|
|||||||
@@ -12,14 +12,17 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import os
|
||||||
import subprocess
|
import subprocess
|
||||||
import tempfile
|
import tempfile
|
||||||
import textwrap
|
import textwrap
|
||||||
|
|
||||||
from transformers import is_torch_available
|
from transformers import is_torch_available
|
||||||
|
from transformers.integrations.tensor_parallel import get_packed_weights, repack_weights
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
TestCasePlus,
|
TestCasePlus,
|
||||||
get_torch_dist_unique_port,
|
get_torch_dist_unique_port,
|
||||||
|
require_huggingface_hub_greater_or_equal,
|
||||||
require_torch_multi_gpu,
|
require_torch_multi_gpu,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -28,19 +31,51 @@ if is_torch_available():
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class TestTensorParallelUtils(TestCasePlus):
    def test_packed_unpacked_conversion(self):
        """Shard packed weights for each rank with `get_packed_weights`, then check `repack_weights` restores them."""
        WORLD_SIZE = 2
        PACKED_BLOCK_SIZE = 800
        SHARDING_DIM = 2
        NUM_BLOCKS = 2

        full_shape = (4, 512, 2 * PACKED_BLOCK_SIZE)
        original_packed_weights = torch.randn(*full_shape)
        original_packed_weights.get_dtype = lambda: "F32"  # get_packed_weights expects PySlice object
        empty_param = torch.empty(*full_shape)

        class MockDeviceMesh:
            def size(self):
                return WORLD_SIZE

        # get_packed_weights only calls `.size()`, do this to avoid doing actual distributed run
        mock_mesh = MockDeviceMesh()

        per_rank_shards = [
            get_packed_weights(original_packed_weights, empty_param, mock_mesh, rank, SHARDING_DIM)
            for rank in (0, 1)
        ]

        # simulate all gather of sharded weights
        packed_weights = torch.cat(per_rank_shards, dim=SHARDING_DIM)
        unpacked_weights = repack_weights(packed_weights, SHARDING_DIM, WORLD_SIZE, NUM_BLOCKS)

        assert torch.allclose(unpacked_weights, original_packed_weights)
|
||||||
|
|
||||||
|
|
||||||
# RUN_SLOW=1 pytest -sv tests/tensor_parallel/test_tensor_parallel.py
|
# RUN_SLOW=1 pytest -sv tests/tensor_parallel/test_tensor_parallel.py
|
||||||
class TestTensorParallel(TestCasePlus):
|
class TestTensorParallel(TestCasePlus):
|
||||||
nproc_per_node = 2
|
nproc_per_node = 2
|
||||||
|
|
||||||
def torchrun(self, script: str):
|
def torchrun(self, script: str, is_torchrun: bool = True):
|
||||||
"""Run the `script` using `torchrun` command for multi-processing in a subprocess. Captures errors as necessary."""
|
"""Run the `script` using `torchrun` command for multi-processing in a subprocess. Captures errors as necessary."""
|
||||||
with tempfile.NamedTemporaryFile(mode="w+", suffix=".py") as tmp:
|
with tempfile.NamedTemporaryFile(mode="w+", suffix=".py") as tmp:
|
||||||
tmp.write(script)
|
tmp.write(script)
|
||||||
tmp.flush()
|
tmp.flush()
|
||||||
tmp.seek(0)
|
tmp.seek(0)
|
||||||
cmd = (
|
if is_torchrun:
|
||||||
f"torchrun --nproc_per_node {self.nproc_per_node} --master_port {get_torch_dist_unique_port()} {tmp.name}"
|
cmd = (
|
||||||
).split()
|
f"torchrun --nproc_per_node {self.nproc_per_node} --master_port {get_torch_dist_unique_port()} {tmp.name}"
|
||||||
|
).split()
|
||||||
|
else:
|
||||||
|
cmd = ["python", tmp.name]
|
||||||
|
|
||||||
# Note that the subprocess will be waited for here, and raise an error if not successful
|
# Note that the subprocess will be waited for here, and raise an error if not successful
|
||||||
try:
|
try:
|
||||||
@@ -88,6 +123,48 @@ class TestTensorParallel(TestCasePlus):
|
|||||||
)
|
)
|
||||||
self.torchrun(script_to_run)
|
self.torchrun(script_to_run)
|
||||||
|
|
||||||
|
@require_huggingface_hub_greater_or_equal("0.31.4")
def test_model_save(self):
    """
    Save the same model once with tensor parallelism (under torchrun) and once without (plain python), then check
    that every tensor of the non-TP checkpoint is matched by the tensor of the same name in the TP checkpoint.
    """
    from safetensors import safe_open

    with tempfile.TemporaryDirectory() as tmp_dir:
        for is_torchrun in [True, False]:
            # The script picks the TP or non-TP branch based on the RANK environment variable.
            script_to_run = textwrap.dedent(
                f"""
                import torch
                import os
                from transformers import AutoModelForCausalLM

                model_id = "JackFram/llama-68m"
                kwargs = dict()

                if os.environ.get("RANK", None) is not None:
                    kwargs["tp_plan"] = "auto"
                    result_dir = "{tmp_dir}/tp"
                else:
                    result_dir = "{tmp_dir}/nontp"

                model = AutoModelForCausalLM.from_pretrained(model_id, **kwargs)
                model.save_pretrained(result_dir)
                """
            )
            self.torchrun(script_to_run, is_torchrun=is_torchrun)

        non_tp_model_path = os.path.join(tmp_dir, "nontp")
        tp_model_path = os.path.join(tmp_dir, "tp")

        for filename in os.listdir(non_tp_model_path):
            if not filename.endswith(".safetensors"):
                continue

            # Open both checkpoints as context managers so the file handles are released before the temporary
            # directory is cleaned up.
            with safe_open(os.path.join(non_tp_model_path, filename), device="cpu", framework="pt") as non_tp_model:
                with safe_open(os.path.join(tp_model_path, filename), device="cpu", framework="pt") as tp_model:
                    for non_tp_key in non_tp_model.keys():
                        non_tp_tensor = non_tp_model.get_tensor(non_tp_key)
                        tp_tensor = tp_model.get_tensor(non_tp_key)
                        assert torch.allclose(non_tp_tensor, tp_tensor), (
                            f"Tensor with key: {non_tp_key} does not match"
                        )
                        del non_tp_tensor, tp_tensor
|
||||||
|
|
||||||
|
|
||||||
@require_torch_multi_gpu
|
@require_torch_multi_gpu
|
||||||
class TestTensorParallelCuda(TestTensorParallel):
|
class TestTensorParallelCuda(TestTensorParallel):
|
||||||
|
|||||||
Reference in New Issue
Block a user