[generate] torch.distributed-compatible DynamicCache (#36373)
* test * docstring * prepare distributed cache data * fix cat dim * test mvp * add test checks * like this? * working test and solution * nit * nit * add shape info
This commit is contained in:
@@ -3,7 +3,7 @@ import importlib.metadata
|
||||
import json
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from packaging import version
|
||||
@@ -358,12 +358,23 @@ class DynamicCache(Cache):
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
def __init__(self, _distributed_cache_data: Iterable = None) -> None:
|
||||
super().__init__()
|
||||
self._seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen
|
||||
self.key_cache: List[torch.Tensor] = []
|
||||
self.value_cache: List[torch.Tensor] = []
|
||||
|
||||
# `_distributed_cache_data` was originally added for compatibility with `torch.distributed` (DDP). See #36121
|
||||
# and #36373 for more information. In a nutshell, it is `map(gather_map, zip(*caches))`, i.e. each item in the
|
||||
# iterable contains the key and value states for a layer gathered across replicas by torch.distributed
|
||||
# (shape=[global batch size, num_heads, seq_len, head_dim]).
|
||||
# WARNING: `_distributed_cache_data` must be the first argument in `__init__`, otherwise we'll break
|
||||
# compatibility. The name of the argument doesn't matter.
|
||||
if _distributed_cache_data is not None:
|
||||
for key_states, value_states in _distributed_cache_data:
|
||||
self.key_cache.append(key_states)
|
||||
self.value_cache.append(value_states)
|
||||
|
||||
def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
|
||||
"""
|
||||
Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
|
||||
|
||||
@@ -20,12 +20,14 @@ from parameterized import parameterized
|
||||
|
||||
from transformers import set_seed
|
||||
from transformers.testing_utils import (
|
||||
get_gpu_count,
|
||||
is_torch_available,
|
||||
require_gptq,
|
||||
require_non_xpu,
|
||||
require_read_token,
|
||||
require_torch,
|
||||
require_torch_gpu,
|
||||
require_torch_multi_gpu,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
@@ -620,3 +622,35 @@ class CacheIntegrationTest(unittest.TestCase):
|
||||
'You are a helpful assistant. What is the capital of France?\n\n\n## Response:Paris is the capital of France.\n\n\n\n\n\n## Query:\n\nIn a detailed analysis, compare the economic impacts of the introduction of the'
|
||||
] # fmt: skip
|
||||
self.assertEqual(responses, EXPECTED_DECODED_TEXT)
|
||||
|
||||
@require_torch_multi_gpu
|
||||
def test_data_parallel_dynamic_cache(self):
|
||||
"""
|
||||
Tests that the dynamic cache works with nn.DataParallel. Under the hood, `DynamicCache` is rebuilt from
|
||||
multiple `DynamicCache` in the gather step.
|
||||
"""
|
||||
|
||||
model_repo = "hf-internal-testing/tiny-random-MistralForCausalLM"
|
||||
model = AutoModelForCausalLM.from_pretrained(model_repo).to(torch_device)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_repo)
|
||||
|
||||
# w/o DP: batch_size = num_gpu
|
||||
# w DP: batch_size = 1 (with num_gpus replicas)
|
||||
num_gpus = get_gpu_count()
|
||||
model_inputs = tokenizer(["foo bar"] * num_gpus, return_tensors="pt").to(model.device)
|
||||
|
||||
# w/o DP
|
||||
no_parallelism_cache = model(**model_inputs).past_key_values
|
||||
self.assertIsInstance(no_parallelism_cache, DynamicCache)
|
||||
|
||||
# w DP
|
||||
model = torch.nn.DataParallel(model)
|
||||
parallelism_cache = model(**model_inputs).past_key_values
|
||||
self.assertIsInstance(parallelism_cache, DynamicCache)
|
||||
|
||||
# Check that the caches are the same
|
||||
for layer_idx in range(len(no_parallelism_cache)):
|
||||
for kv_idx in range(2): # 0 = key, 1 = value
|
||||
torch.testing.assert_close(
|
||||
actual=parallelism_cache[layer_idx][kv_idx], expected=no_parallelism_cache[layer_idx][kv_idx]
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user