enable static cache on TP model (#39164)

* enable static cache on TP model

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* check tp size before init kv cache

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix docstring

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* add tp tests

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix comment

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix other cache head size

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

---------

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
This commit is contained in:
jiqing-feng
2025-07-10 05:14:45 +08:00
committed by GitHub
parent 2ef59646b8
commit aff7df8436
4 changed files with 84 additions and 1 deletions

View File

@@ -1098,6 +1098,10 @@ class StaticCache(Cache):
Mapping between the layers and its device. This is required when you are manually initializing the cache Mapping between the layers and its device. This is required when you are manually initializing the cache
and the model is split between different gpus. You can know which layers mapped to which device by and the model is split between different gpus. You can know which layers mapped to which device by
checking the associated device_map: `model.hf_device_map`. checking the associated device_map: `model.hf_device_map`.
tp_size (`Optional[int]`, *optional*):
The tensor parallel size of the model. This is used to adjust the number of key/value heads in the cache
if the model is using tensor parallelism. If not provided, it defaults to `None`, which means that the
number of key/value heads will not be adjusted.
Example: Example:
@@ -1130,6 +1134,7 @@ class StaticCache(Cache):
device: Union[torch.device, str, None] = None, device: Union[torch.device, str, None] = None,
dtype: torch.dtype = torch.float32, dtype: torch.dtype = torch.float32,
layer_device_map: Optional[dict[int, Union[str, torch.device, int]]] = None, layer_device_map: Optional[dict[int, Union[str, torch.device, int]]] = None,
tp_size: Optional[int] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.max_batch_size = max_batch_size self.max_batch_size = max_batch_size
@@ -1144,6 +1149,13 @@ class StaticCache(Cache):
if getattr(config, "num_key_value_heads", None) is None if getattr(config, "num_key_value_heads", None) is None
else config.num_key_value_heads else config.num_key_value_heads
) )
if tp_size is not None and tp_size > 1:
if self.num_key_value_heads % tp_size != 0:
raise ValueError(
f"Number of key value heads {self.num_key_value_heads} must be divisible by tensor parallel size {tp_size}."
)
# If the model is using tensor parallelism, we need to adjust the number of heads accordingly.
self.num_key_value_heads //= tp_size
self.key_cache: list[torch.Tensor] = [] self.key_cache: list[torch.Tensor] = []
self.value_cache: list[torch.Tensor] = [] self.value_cache: list[torch.Tensor] = []
@@ -1573,6 +1585,10 @@ class HybridCache(Cache):
Mapping between the layers and its device. This is required when you are manually initializing the cache Mapping between the layers and its device. This is required when you are manually initializing the cache
and the model is split between different gpus. You can know which layers mapped to which device by and the model is split between different gpus. You can know which layers mapped to which device by
checking the associated device_map: `model.hf_device_map`. checking the associated device_map: `model.hf_device_map`.
tp_size (`Optional[int]`, *optional*):
The tensor parallel size of the model. This is used to adjust the number of key/value heads in the cache
if the model is using tensor parallelism. If not provided, it defaults to `None`, which means that the
number of key/value heads will not be adjusted.
Example: Example:
@@ -1604,6 +1620,7 @@ class HybridCache(Cache):
device: Union[torch.device, str, None] = None, device: Union[torch.device, str, None] = None,
dtype: torch.dtype = torch.float32, dtype: torch.dtype = torch.float32,
layer_device_map: Optional[dict[int, Union[str, torch.device, int]]] = None, layer_device_map: Optional[dict[int, Union[str, torch.device, int]]] = None,
tp_size: Optional[int] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
if not hasattr(config, "sliding_window") or config.sliding_window is None: if not hasattr(config, "sliding_window") or config.sliding_window is None:
@@ -1627,6 +1644,13 @@ class HybridCache(Cache):
if getattr(config, "num_key_value_heads", None) is None if getattr(config, "num_key_value_heads", None) is None
else config.num_key_value_heads else config.num_key_value_heads
) )
if tp_size is not None and tp_size > 1:
if self.num_key_value_heads % tp_size != 0:
raise ValueError(
f"Number of key value heads {self.num_key_value_heads} must be divisible by tensor parallel size {tp_size}."
)
# If the model is using tensor parallelism, we need to adjust the number of heads accordingly.
self.num_key_value_heads //= tp_size
# If the attribute does not exist in the config, fallback to a simple StaticCache # If the attribute does not exist in the config, fallback to a simple StaticCache
if hasattr(config, "layer_types"): if hasattr(config, "layer_types"):
@@ -2197,6 +2221,10 @@ class OffloadedStaticCache(StaticCache):
Mapping between the layers and its device. This is required when you are manually initializing the cache Mapping between the layers and its device. This is required when you are manually initializing the cache
and the model is split between different gpus. You can know which layers mapped to which device by and the model is split between different gpus. You can know which layers mapped to which device by
checking the associated device_map: `model.hf_device_map`. checking the associated device_map: `model.hf_device_map`.
tp_size (`Optional[int]`, *optional*):
The tensor parallel size of the model. This is used to adjust the number of key/value heads in the cache
if the model is using tensor parallelism. If not provided, it defaults to `None`, which means that the
number of key/value heads will not be adjusted.
Example: Example:
@@ -2228,6 +2256,7 @@ class OffloadedStaticCache(StaticCache):
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
offload_device: Union[str, torch.device] = torch.device("cpu"), offload_device: Union[str, torch.device] = torch.device("cpu"),
layer_device_map: Optional[dict[int, Union[str, torch.device, int]]] = None, layer_device_map: Optional[dict[int, Union[str, torch.device, int]]] = None,
tp_size: Optional[int] = None,
) -> None: ) -> None:
super(Cache, self).__init__() super(Cache, self).__init__()
@@ -2251,6 +2280,13 @@ class OffloadedStaticCache(StaticCache):
if getattr(config, "num_key_value_heads", None) is None if getattr(config, "num_key_value_heads", None) is None
else config.num_key_value_heads else config.num_key_value_heads
) )
if tp_size is not None and tp_size > 1:
if num_key_value_heads % tp_size != 0:
raise ValueError(
f"Number of key value heads {num_key_value_heads} must be divisible by tensor parallel size {tp_size}."
)
# If the model is using tensor parallelism, we need to adjust the number of heads accordingly.
num_key_value_heads //= tp_size
cache_shape = (max_batch_size, num_key_value_heads, self.max_cache_len, head_dim) cache_shape = (max_batch_size, num_key_value_heads, self.max_cache_len, head_dim)

View File

@@ -1963,6 +1963,9 @@ class GenerationMixin(ContinuousMixin):
"device": device, "device": device,
"layer_device_map": layer_device_map, "layer_device_map": layer_device_map,
} }
if cache_implementation in ["static", "hybrid", "offloaded_static"]:
cache_kwargs.update({"tp_size": self.tp_size})
self._cache = cache_cls(**cache_kwargs) self._cache = cache_cls(**cache_kwargs)
if requires_cross_attention_cache: if requires_cross_attention_cache:
encoder_kwargs = cache_kwargs.copy() encoder_kwargs = cache_kwargs.copy()

View File

@@ -4494,6 +4494,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
raise ValueError("device_mesh must be 1 dimensional and will be used for TP") raise ValueError("device_mesh must be 1 dimensional and will be used for TP")
device_map = torch.device(device_mesh.device_type, int(os.environ["LOCAL_RANK"])) device_map = torch.device(device_mesh.device_type, int(os.environ["LOCAL_RANK"]))
if tp_size is None:
tp_size = torch.distributed.get_world_size()
if use_auth_token is not None: if use_auth_token is not None:
warnings.warn( warnings.warn(
"The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.", "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",

View File

@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# Run the test: CUDA_VISIBLE_DEVICES=0,1 RUN_SLOW=1 pytest -sv tests/tensor_parallel/test_tensor_parallel.py
import os import os
import subprocess import subprocess
import tempfile import tempfile
@@ -62,7 +64,6 @@ class TestTensorParallelUtils(TestCasePlus):
assert torch.allclose(unpacked_weights, original_packed_weights) assert torch.allclose(unpacked_weights, original_packed_weights)
# RUN_SLOW=1 pytest -sv tests/tensor_parallel/test_tensor_parallel.py
class TestTensorParallel(TestCasePlus): class TestTensorParallel(TestCasePlus):
nproc_per_node = 2 nproc_per_node = 2
@@ -125,6 +126,46 @@ class TestTensorParallel(TestCasePlus):
) )
self.torchrun(script_to_run) self.torchrun(script_to_run)
def test_model_generate(self):
script_to_run = textwrap.dedent(
"""
import torch
import os
from transformers import AutoModelForCausalLM, AutoTokenizer
model_id = "JackFram/llama-68m"
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto", tp_plan="auto")
torch.distributed.barrier()
model.forward = torch.compile(model.forward)
has_dtensor = 0
for name, parameter in model.named_parameters():
if isinstance(parameter.data, torch.distributed.tensor.DTensor):
has_dtensor = 1
break
assert has_dtensor == 1, "TP model must has DTensor"
tokenizer = AutoTokenizer.from_pretrained(model_id)
prompt = "Can I help"
inputs = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
outputs = model.generate(inputs, max_new_tokens=10, cache_implementation="static")
output_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)
assert output_text[0].startswith(prompt), f"Expected output to start with '{prompt}', got '{output_text[0]}'"
torch.distributed.barrier()
torch.distributed.destroy_process_group()
"""
)
self.torchrun(script_to_run)
@require_huggingface_hub_greater_or_equal("0.31.4") @require_huggingface_hub_greater_or_equal("0.31.4")
def test_model_save(self): def test_model_save(self):
from safetensors import safe_open from safetensors import safe_open