[generate] torch.distributed-compatible DynamicCache (#36373)
* test
* docstring
* prepare distributed cache data
* fix cat dim
* test mvp
* add test checks
* like this?
* working test and solution
* nit
* nit
* add shape info
This commit is contained in:
@@ -3,7 +3,7 @@ import importlib.metadata
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from packaging import version
|
from packaging import version
|
||||||
@@ -358,12 +358,23 @@ class DynamicCache(Cache):
|
|||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self, _distributed_cache_data: Iterable = None) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen
|
self._seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen
|
||||||
self.key_cache: List[torch.Tensor] = []
|
self.key_cache: List[torch.Tensor] = []
|
||||||
self.value_cache: List[torch.Tensor] = []
|
self.value_cache: List[torch.Tensor] = []
|
||||||
|
|
||||||
|
# `_distributed_cache_data` was originally added for compatibility with `torch.distributed` (DDP). See #36121
|
||||||
|
# and #36373 for more information. In a nutshell, it is `map(gather_map, zip(*caches))`, i.e. each item in the
|
||||||
|
# iterable contains the key and value states for a layer gathered across replicas by torch.distributed
|
||||||
|
# (shape=[global batch size, num_heads, seq_len, head_dim]).
|
||||||
|
# WARNING: `_distributed_cache_data` must be the first argument in `__init__`, otherwise we'll break
|
||||||
|
# compatibility. The name of the argument doesn't matter.
|
||||||
|
if _distributed_cache_data is not None:
|
||||||
|
for key_states, value_states in _distributed_cache_data:
|
||||||
|
self.key_cache.append(key_states)
|
||||||
|
self.value_cache.append(value_states)
|
||||||
|
|
||||||
def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
|
def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
|
||||||
"""
|
"""
|
||||||
Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
|
Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
|
||||||
|
|||||||
@@ -20,12 +20,14 @@ from parameterized import parameterized
|
|||||||
|
|
||||||
from transformers import set_seed
|
from transformers import set_seed
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
|
get_gpu_count,
|
||||||
is_torch_available,
|
is_torch_available,
|
||||||
require_gptq,
|
require_gptq,
|
||||||
require_non_xpu,
|
require_non_xpu,
|
||||||
require_read_token,
|
require_read_token,
|
||||||
require_torch,
|
require_torch,
|
||||||
require_torch_gpu,
|
require_torch_gpu,
|
||||||
|
require_torch_multi_gpu,
|
||||||
slow,
|
slow,
|
||||||
torch_device,
|
torch_device,
|
||||||
)
|
)
|
||||||
@@ -620,3 +622,35 @@ class CacheIntegrationTest(unittest.TestCase):
|
|||||||
'You are a helpful assistant. What is the capital of France?\n\n\n## Response:Paris is the capital of France.\n\n\n\n\n\n## Query:\n\nIn a detailed analysis, compare the economic impacts of the introduction of the'
|
'You are a helpful assistant. What is the capital of France?\n\n\n## Response:Paris is the capital of France.\n\n\n\n\n\n## Query:\n\nIn a detailed analysis, compare the economic impacts of the introduction of the'
|
||||||
] # fmt: skip
|
] # fmt: skip
|
||||||
self.assertEqual(responses, EXPECTED_DECODED_TEXT)
|
self.assertEqual(responses, EXPECTED_DECODED_TEXT)
|
||||||
|
|
||||||
|
@require_torch_multi_gpu
|
||||||
|
def test_data_parallel_dynamic_cache(self):
|
||||||
|
"""
|
||||||
|
Tests that the dynamic cache works with nn.DataParallel. Under the hood, `DynamicCache` is rebuilt from
|
||||||
|
multiple `DynamicCache` instances in the gather step.
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_repo = "hf-internal-testing/tiny-random-MistralForCausalLM"
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(model_repo).to(torch_device)
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_repo)
|
||||||
|
|
||||||
|
# w/o DP: batch_size = num_gpus
|
||||||
|
# w DP: batch_size = 1 (with num_gpus replicas)
|
||||||
|
num_gpus = get_gpu_count()
|
||||||
|
model_inputs = tokenizer(["foo bar"] * num_gpus, return_tensors="pt").to(model.device)
|
||||||
|
|
||||||
|
# w/o DP
|
||||||
|
no_parallelism_cache = model(**model_inputs).past_key_values
|
||||||
|
self.assertIsInstance(no_parallelism_cache, DynamicCache)
|
||||||
|
|
||||||
|
# w DP
|
||||||
|
model = torch.nn.DataParallel(model)
|
||||||
|
parallelism_cache = model(**model_inputs).past_key_values
|
||||||
|
self.assertIsInstance(parallelism_cache, DynamicCache)
|
||||||
|
|
||||||
|
# Check that the caches are the same
|
||||||
|
for layer_idx in range(len(no_parallelism_cache)):
|
||||||
|
for kv_idx in range(2): # 0 = key, 1 = value
|
||||||
|
torch.testing.assert_close(
|
||||||
|
actual=parallelism_cache[layer_idx][kv_idx], expected=no_parallelism_cache[layer_idx][kv_idx]
|
||||||
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user