From aff7df8436dde04762170d3d0fbe906c7216d6f2 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 10 Jul 2025 05:14:45 +0800 Subject: [PATCH] enable static cache on TP model (#39164) * enable static cache on TP model Signed-off-by: jiqing-feng * check tp size before init kv cache Signed-off-by: jiqing-feng * fix docstring Signed-off-by: jiqing-feng * add tp tests Signed-off-by: jiqing-feng * fix comment Signed-off-by: jiqing-feng * fix other cache head size Signed-off-by: jiqing-feng --------- Signed-off-by: jiqing-feng --- src/transformers/cache_utils.py | 36 ++++++++++++++++ src/transformers/generation/utils.py | 3 ++ src/transformers/modeling_utils.py | 3 ++ tests/tensor_parallel/test_tensor_parallel.py | 43 ++++++++++++++++++- 4 files changed, 84 insertions(+), 1 deletion(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 04ccc6f7ef..bfba7e2bbd 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -1098,6 +1098,10 @@ class StaticCache(Cache): Mapping between the layers and its device. This is required when you are manually initializing the cache and the model is split between different gpus. You can know which layers mapped to which device by checking the associated device_map: `model.hf_device_map`. + tp_size (`Optional[int]`, *optional*): + The tensor parallel size of the model. This is used to adjust the number of key/value heads in the cache + if the model is using tensor parallelism. If not provided, it defaults to `None`, which means that the + number of key/value heads will not be adjusted. Example: @@ -1130,6 +1134,7 @@ class StaticCache(Cache): device: Union[torch.device, str, None] = None, dtype: torch.dtype = torch.float32, layer_device_map: Optional[dict[int, Union[str, torch.device, int]]] = None, + tp_size: Optional[int] = None, ) -> None: super().__init__() self.max_batch_size = max_batch_size @@ -1144,6 +1149,13 @@ class StaticCache(Cache): if getattr(config, "num_key_value_heads", None) is None else config.num_key_value_heads ) + if tp_size is not None and tp_size > 1: + if self.num_key_value_heads % tp_size != 0: + raise ValueError( + f"Number of key value heads {self.num_key_value_heads} must be divisible by tensor parallel size {tp_size}." + ) + # If the model is using tensor parallelism, we need to adjust the number of heads accordingly. + self.num_key_value_heads //= tp_size self.key_cache: list[torch.Tensor] = [] self.value_cache: list[torch.Tensor] = [] @@ -1573,6 +1585,10 @@ class HybridCache(Cache): Mapping between the layers and its device. This is required when you are manually initializing the cache and the model is split between different gpus. You can know which layers mapped to which device by checking the associated device_map: `model.hf_device_map`. + tp_size (`Optional[int]`, *optional*): + The tensor parallel size of the model. This is used to adjust the number of key/value heads in the cache + if the model is using tensor parallelism. If not provided, it defaults to `None`, which means that the + number of key/value heads will not be adjusted. Example: @@ -1604,6 +1620,7 @@ class HybridCache(Cache): device: Union[torch.device, str, None] = None, dtype: torch.dtype = torch.float32, layer_device_map: Optional[dict[int, Union[str, torch.device, int]]] = None, + tp_size: Optional[int] = None, ) -> None: super().__init__() if not hasattr(config, "sliding_window") or config.sliding_window is None: @@ -1627,6 +1644,13 @@ class HybridCache(Cache): if getattr(config, "num_key_value_heads", None) is None else config.num_key_value_heads ) + if tp_size is not None and tp_size > 1: + if self.num_key_value_heads % tp_size != 0: + raise ValueError( + f"Number of key value heads {self.num_key_value_heads} must be divisible by tensor parallel size {tp_size}." + ) + # If the model is using tensor parallelism, we need to adjust the number of heads accordingly. + self.num_key_value_heads //= tp_size # If the attribute does not exist in the config, fallback to a simple StaticCache if hasattr(config, "layer_types"): @@ -2197,6 +2221,10 @@ class OffloadedStaticCache(StaticCache): Mapping between the layers and its device. This is required when you are manually initializing the cache and the model is split between different gpus. You can know which layers mapped to which device by checking the associated device_map: `model.hf_device_map`. + tp_size (`Optional[int]`, *optional*): + The tensor parallel size of the model. This is used to adjust the number of key/value heads in the cache + if the model is using tensor parallelism. If not provided, it defaults to `None`, which means that the + number of key/value heads will not be adjusted. Example: @@ -2228,6 +2256,7 @@ class OffloadedStaticCache(StaticCache): dtype: Optional[torch.dtype] = None, offload_device: Union[str, torch.device] = torch.device("cpu"), layer_device_map: Optional[dict[int, Union[str, torch.device, int]]] = None, + tp_size: Optional[int] = None, ) -> None: super(Cache, self).__init__() @@ -2251,6 +2280,13 @@ class OffloadedStaticCache(StaticCache): if getattr(config, "num_key_value_heads", None) is None else config.num_key_value_heads ) + if tp_size is not None and tp_size > 1: + if num_key_value_heads % tp_size != 0: + raise ValueError( + f"Number of key value heads {num_key_value_heads} must be divisible by tensor parallel size {tp_size}." + ) + # If the model is using tensor parallelism, we need to adjust the number of heads accordingly. + num_key_value_heads //= tp_size cache_shape = (max_batch_size, num_key_value_heads, self.max_cache_len, head_dim) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index c778e9e012..6208945434 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1963,6 +1963,9 @@ class GenerationMixin(ContinuousMixin): "device": device, "layer_device_map": layer_device_map, } + if cache_implementation in ["static", "hybrid", "offloaded_static"]: + cache_kwargs.update({"tp_size": self.tp_size}) + self._cache = cache_cls(**cache_kwargs) if requires_cross_attention_cache: encoder_kwargs = cache_kwargs.copy() diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 37d21f9fdf..79a2c294c6 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -4494,6 +4494,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi raise ValueError("device_mesh must be 1 dimensional and will be used for TP") device_map = torch.device(device_mesh.device_type, int(os.environ["LOCAL_RANK"])) + if tp_size is None: + tp_size = torch.distributed.get_world_size() + if use_auth_token is not None: warnings.warn( "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.", diff --git a/tests/tensor_parallel/test_tensor_parallel.py b/tests/tensor_parallel/test_tensor_parallel.py index 69abd550e5..980d6fff8d 100644 --- a/tests/tensor_parallel/test_tensor_parallel.py +++ b/tests/tensor_parallel/test_tensor_parallel.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +# Run the test: CUDA_VISIBLE_DEVICES=0,1 RUN_SLOW=1 pytest -sv tests/tensor_parallel/test_tensor_parallel.py + import os import subprocess import tempfile @@ -62,7 +64,6 @@ class TestTensorParallelUtils(TestCasePlus): assert torch.allclose(unpacked_weights, original_packed_weights) -# RUN_SLOW=1 pytest -sv tests/tensor_parallel/test_tensor_parallel.py class TestTensorParallel(TestCasePlus): nproc_per_node = 2 @@ -125,6 +126,46 @@ class TestTensorParallel(TestCasePlus): ) self.torchrun(script_to_run) + def test_model_generate(self): + script_to_run = textwrap.dedent( + """ + import torch + import os + from transformers import AutoModelForCausalLM, AutoTokenizer + + model_id = "JackFram/llama-68m" + + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + + model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto", tp_plan="auto") + torch.distributed.barrier() + + model.forward = torch.compile(model.forward) + + has_dtensor = 0 + for name, parameter in model.named_parameters(): + if isinstance(parameter.data, torch.distributed.tensor.DTensor): + has_dtensor = 1 + break + + assert has_dtensor == 1, "TP model must has DTensor" + + tokenizer = AutoTokenizer.from_pretrained(model_id) + prompt = "Can I help" + + inputs = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device) + outputs = model.generate(inputs, max_new_tokens=10, cache_implementation="static") + + output_text = tokenizer.batch_decode(outputs, skip_special_tokens=True) + assert output_text[0].startswith(prompt), f"Expected output to start with '{prompt}', got '{output_text[0]}'" + + torch.distributed.barrier() + torch.distributed.destroy_process_group() + """ + ) + self.torchrun(script_to_run) + @require_huggingface_hub_greater_or_equal("0.31.4") def test_model_save(self): from safetensors import safe_open