enable static cache on TP model (#39164)

* enable static cache on TP model

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* check tp size before init kv cache

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix docstring

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* add tp tests

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix comment

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix other cache head size

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

---------

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
This commit is contained in:
jiqing-feng
2025-07-10 05:14:45 +08:00
committed by GitHub
parent 2ef59646b8
commit aff7df8436
4 changed files with 84 additions and 1 deletions

View File

@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Run the test: CUDA_VISIBLE_DEVICES=0,1 RUN_SLOW=1 pytest -sv tests/tensor_parallel/test_tensor_parallel.py
import os
import subprocess
import tempfile
@@ -62,7 +64,6 @@ class TestTensorParallelUtils(TestCasePlus):
assert torch.allclose(unpacked_weights, original_packed_weights)
# RUN_SLOW=1 pytest -sv tests/tensor_parallel/test_tensor_parallel.py
class TestTensorParallel(TestCasePlus):
nproc_per_node = 2
@@ -125,6 +126,46 @@ class TestTensorParallel(TestCasePlus):
)
self.torchrun(script_to_run)
def test_model_generate(self):
    """Check that a tensor-parallel model can run ``generate`` with a static cache.

    Builds a standalone script and hands it to ``self.torchrun``. The script
    loads ``JackFram/llama-68m`` with ``tp_plan="auto"``, wraps ``model.forward``
    in ``torch.compile``, asserts that at least one parameter is a ``DTensor``,
    then generates 10 new tokens with ``cache_implementation="static"`` and
    asserts that the decoded text still starts with the prompt.
    """
    script_to_run = textwrap.dedent(
        """
        import torch
        import os
        from transformers import AutoModelForCausalLM, AutoTokenizer

        model_id = "JackFram/llama-68m"

        rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])

        model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto", tp_plan="auto")
        torch.distributed.barrier()

        model.forward = torch.compile(model.forward)

        # At least one parameter must be sharded, otherwise the TP plan was not applied.
        has_dtensor = any(
            isinstance(parameter.data, torch.distributed.tensor.DTensor)
            for _, parameter in model.named_parameters()
        )
        assert has_dtensor, "TP model must have at least one DTensor parameter"

        tokenizer = AutoTokenizer.from_pretrained(model_id)
        prompt = "Can I help"

        inputs = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
        outputs = model.generate(inputs, max_new_tokens=10, cache_implementation="static")

        output_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)
        assert output_text[0].startswith(prompt), f"Expected output to start with '{prompt}', got '{output_text[0]}'"

        torch.distributed.barrier()
        torch.distributed.destroy_process_group()
        """
    )
    self.torchrun(script_to_run)
@require_huggingface_hub_greater_or_equal("0.31.4")
def test_model_save(self):
from safetensors import safe_open