[generate] torch.distributed-compatible DynamicCache (#36373)

* test

* docstring

* prepare distributed cache data

* fix cat dim

* test mvp

* add test checks

* like this?

* working test and solution

* nit

* nit

* add shape info
This commit is contained in:
Joao Gante
2025-02-27 11:48:57 +00:00
committed by GitHub
parent 17792556b2
commit 8aed019764
2 changed files with 47 additions and 2 deletions

View File

@@ -3,7 +3,7 @@ import importlib.metadata
import json
import os
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
import torch
from packaging import version
@@ -358,12 +358,23 @@ class DynamicCache(Cache):
```
"""
def __init__(self) -> None:
def __init__(self, _distributed_cache_data: Iterable = None) -> None:
    """
    Create an empty cache, optionally pre-filled with per-layer key/value states.

    Args:
        _distributed_cache_data (`Iterable`, *optional*):
            Added for compatibility with `torch.distributed` (DDP); see #36121 and #36373. In a nutshell, it is
            `map(gather_map, zip(*caches))`: each item of the iterable holds the key and value states of one
            layer, gathered across replicas by torch.distributed
            (shape=[global batch size, num_heads, seq_len, head_dim]).

    WARNING: `_distributed_cache_data` must remain the first argument of `__init__`, otherwise compatibility
    breaks. The name of the argument doesn't matter.
    """
    super().__init__()
    # Tally of how many tokens the cache has seen; used in `generate`.
    self._seen_tokens = 0
    self.key_cache: List[torch.Tensor] = []
    self.value_cache: List[torch.Tensor] = []

    # Regular construction: nothing to pre-populate.
    if _distributed_cache_data is None:
        return

    # Rebuild the cache layer by layer from the gathered (key, value) pairs.
    for layer_keys, layer_values in _distributed_cache_data:
        self.key_cache.append(layer_keys)
        self.value_cache.append(layer_values)
def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
"""
Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the

View File

@@ -20,12 +20,14 @@ from parameterized import parameterized
from transformers import set_seed
from transformers.testing_utils import (
get_gpu_count,
is_torch_available,
require_gptq,
require_non_xpu,
require_read_token,
require_torch,
require_torch_gpu,
require_torch_multi_gpu,
slow,
torch_device,
)
@@ -620,3 +622,35 @@ class CacheIntegrationTest(unittest.TestCase):
'You are a helpful assistant. What is the capital of France?\n\n\n## Response:Paris is the capital of France.\n\n\n\n\n\n## Query:\n\nIn a detailed analysis, compare the economic impacts of the introduction of the'
] # fmt: skip
self.assertEqual(responses, EXPECTED_DECODED_TEXT)
@require_torch_multi_gpu
def test_data_parallel_dynamic_cache(self):
    """
    Checks that `DynamicCache` is compatible with nn.DataParallel. In the gather step, a single `DynamicCache`
    is rebuilt under the hood from the multiple `DynamicCache` instances produced by the replicas.
    """
    model_repo = "hf-internal-testing/tiny-random-MistralForCausalLM"
    model = AutoModelForCausalLM.from_pretrained(model_repo).to(torch_device)
    tokenizer = AutoTokenizer.from_pretrained(model_repo)

    # One prompt per GPU:
    #   - without DP, the single model sees batch_size = num_gpus
    #   - with DP, each of the num_gpus replicas sees batch_size = 1
    num_gpus = get_gpu_count()
    prompts = ["foo bar"] * num_gpus
    model_inputs = tokenizer(prompts, return_tensors="pt").to(model.device)

    # Reference run, without DataParallel
    no_parallelism_cache = model(**model_inputs).past_key_values
    self.assertIsInstance(no_parallelism_cache, DynamicCache)

    # Same inputs, with DataParallel
    model = torch.nn.DataParallel(model)
    parallelism_cache = model(**model_inputs).past_key_values
    self.assertIsInstance(parallelism_cache, DynamicCache)

    # Both runs must yield the same key/value states for every layer
    num_layers = len(no_parallelism_cache)
    for layer_idx in range(num_layers):
        gathered_layer = parallelism_cache[layer_idx]
        reference_layer = no_parallelism_cache[layer_idx]
        for kv_idx in (0, 1):  # 0 = key, 1 = value
            torch.testing.assert_close(actual=gathered_layer[kv_idx], expected=reference_layer[kv_idx])