enable static cache on TP model (#39164)

* enable static cache on TP model

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* check tp size before init kv cache

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix docstring

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* add tp tests

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix comment

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix other cache head size

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

---------

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
This commit is contained in:
jiqing-feng
2025-07-10 05:14:45 +08:00
committed by GitHub
parent 2ef59646b8
commit aff7df8436
4 changed files with 84 additions and 1 deletions

View File

@@ -1098,6 +1098,10 @@ class StaticCache(Cache):
Mapping between the layers and its device. This is required when you are manually initializing the cache
and the model is split between different gpus. You can know which layers mapped to which device by
checking the associated device_map: `model.hf_device_map`.
tp_size (`Optional[int]`, *optional*):
The tensor parallel size of the model. This is used to adjust the number of key/value heads in the cache
if the model is using tensor parallelism. If not provided, it defaults to `None`, which means that the
number of key/value heads will not be adjusted.
Example:
@@ -1130,6 +1134,7 @@ class StaticCache(Cache):
device: Union[torch.device, str, None] = None,
dtype: torch.dtype = torch.float32,
layer_device_map: Optional[dict[int, Union[str, torch.device, int]]] = None,
tp_size: Optional[int] = None,
) -> None:
super().__init__()
self.max_batch_size = max_batch_size
@@ -1144,6 +1149,13 @@ class StaticCache(Cache):
if getattr(config, "num_key_value_heads", None) is None
else config.num_key_value_heads
)
if tp_size is not None and tp_size > 1:
if self.num_key_value_heads % tp_size != 0:
raise ValueError(
f"Number of key value heads {self.num_key_value_heads} must be divisible by tensor parallel size {tp_size}."
)
# If the model is using tensor parallelism, we need to adjust the number of heads accordingly.
self.num_key_value_heads //= tp_size
self.key_cache: list[torch.Tensor] = []
self.value_cache: list[torch.Tensor] = []
@@ -1573,6 +1585,10 @@ class HybridCache(Cache):
Mapping between the layers and its device. This is required when you are manually initializing the cache
and the model is split between different gpus. You can know which layers mapped to which device by
checking the associated device_map: `model.hf_device_map`.
tp_size (`Optional[int]`, *optional*):
The tensor parallel size of the model. This is used to adjust the number of key/value heads in the cache
if the model is using tensor parallelism. If not provided, it defaults to `None`, which means that the
number of key/value heads will not be adjusted.
Example:
@@ -1604,6 +1620,7 @@ class HybridCache(Cache):
device: Union[torch.device, str, None] = None,
dtype: torch.dtype = torch.float32,
layer_device_map: Optional[dict[int, Union[str, torch.device, int]]] = None,
tp_size: Optional[int] = None,
) -> None:
super().__init__()
if not hasattr(config, "sliding_window") or config.sliding_window is None:
@@ -1627,6 +1644,13 @@ class HybridCache(Cache):
if getattr(config, "num_key_value_heads", None) is None
else config.num_key_value_heads
)
if tp_size is not None and tp_size > 1:
if self.num_key_value_heads % tp_size != 0:
raise ValueError(
f"Number of key value heads {self.num_key_value_heads} must be divisible by tensor parallel size {tp_size}."
)
# If the model is using tensor parallelism, we need to adjust the number of heads accordingly.
self.num_key_value_heads //= tp_size
# If the attribute does not exist in the config, fallback to a simple StaticCache
if hasattr(config, "layer_types"):
@@ -2197,6 +2221,10 @@ class OffloadedStaticCache(StaticCache):
Mapping between the layers and its device. This is required when you are manually initializing the cache
and the model is split between different gpus. You can know which layers mapped to which device by
checking the associated device_map: `model.hf_device_map`.
tp_size (`Optional[int]`, *optional*):
The tensor parallel size of the model. This is used to adjust the number of key/value heads in the cache
if the model is using tensor parallelism. If not provided, it defaults to `None`, which means that the
number of key/value heads will not be adjusted.
Example:
@@ -2228,6 +2256,7 @@ class OffloadedStaticCache(StaticCache):
dtype: Optional[torch.dtype] = None,
offload_device: Union[str, torch.device] = torch.device("cpu"),
layer_device_map: Optional[dict[int, Union[str, torch.device, int]]] = None,
tp_size: Optional[int] = None,
) -> None:
super(Cache, self).__init__()
@@ -2251,6 +2280,13 @@ class OffloadedStaticCache(StaticCache):
if getattr(config, "num_key_value_heads", None) is None
else config.num_key_value_heads
)
if tp_size is not None and tp_size > 1:
if num_key_value_heads % tp_size != 0:
raise ValueError(
f"Number of key value heads {num_key_value_heads} must be divisible by tensor parallel size {tp_size}."
)
# If the model is using tensor parallelism, we need to adjust the number of heads accordingly.
num_key_value_heads //= tp_size
cache_shape = (max_batch_size, num_key_value_heads, self.max_cache_len, head_dim)

View File

@@ -1963,6 +1963,9 @@ class GenerationMixin(ContinuousMixin):
"device": device,
"layer_device_map": layer_device_map,
}
if cache_implementation in ["static", "hybrid", "offloaded_static"]:
cache_kwargs.update({"tp_size": self.tp_size})
self._cache = cache_cls(**cache_kwargs)
if requires_cross_attention_cache:
encoder_kwargs = cache_kwargs.copy()

View File

@@ -4494,6 +4494,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
raise ValueError("device_mesh must be 1 dimensional and will be used for TP")
device_map = torch.device(device_mesh.device_type, int(os.environ["LOCAL_RANK"]))
if tp_size is None:
tp_size = torch.distributed.get_world_size()
if use_auth_token is not None:
warnings.warn(
"The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",

View File

@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Run the test: CUDA_VISIBLE_DEVICES=0,1 RUN_SLOW=1 pytest -sv tests/tensor_parallel/test_tensor_parallel.py
import os
import subprocess
import tempfile
@@ -62,7 +64,6 @@ class TestTensorParallelUtils(TestCasePlus):
assert torch.allclose(unpacked_weights, original_packed_weights)
# RUN_SLOW=1 pytest -sv tests/tensor_parallel/test_tensor_parallel.py
class TestTensorParallel(TestCasePlus):
nproc_per_node = 2
@@ -125,6 +126,46 @@ class TestTensorParallel(TestCasePlus):
)
self.torchrun(script_to_run)
def test_model_generate(self):
script_to_run = textwrap.dedent(
"""
import torch
import os
from transformers import AutoModelForCausalLM, AutoTokenizer
model_id = "JackFram/llama-68m"
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto", tp_plan="auto")
torch.distributed.barrier()
model.forward = torch.compile(model.forward)
has_dtensor = 0
for name, parameter in model.named_parameters():
if isinstance(parameter.data, torch.distributed.tensor.DTensor):
has_dtensor = 1
break
assert has_dtensor == 1, "TP model must has DTensor"
tokenizer = AutoTokenizer.from_pretrained(model_id)
prompt = "Can I help"
inputs = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
outputs = model.generate(inputs, max_new_tokens=10, cache_implementation="static")
output_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)
assert output_text[0].startswith(prompt), f"Expected output to start with '{prompt}', got '{output_text[0]}'"
torch.distributed.barrier()
torch.distributed.destroy_process_group()
"""
)
self.torchrun(script_to_run)
@require_huggingface_hub_greater_or_equal("0.31.4")
def test_model_save(self):
from safetensors import safe_open