From 8aed0197642c07508c8c85dbe1743a90dad5bb22 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Thu, 27 Feb 2025 11:48:57 +0000 Subject: [PATCH] [generate] `torch.distributed`-compatible `DynamicCache` (#36373) * test * docstring * prepare distributed cache data * fix cat dim * test mvp * add test checks * like this? * working test and solution * nit * nit * add shape info --- src/transformers/cache_utils.py | 15 +++++++++++++-- tests/utils/test_cache_utils.py | 34 +++++++++++++++++++++++++++++++++ 2 files changed, 47 insertions(+), 2 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index dd12aa5751..87a6a8fb1a 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -3,7 +3,7 @@ import importlib.metadata import json import os from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union import torch from packaging import version @@ -358,12 +358,23 @@ class DynamicCache(Cache): ``` """ - def __init__(self) -> None: + def __init__(self, _distributed_cache_data: Iterable = None) -> None: super().__init__() self._seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen self.key_cache: List[torch.Tensor] = [] self.value_cache: List[torch.Tensor] = [] + # `_distributed_cache_data` was originally added for compatibility with `torch.distributed` (DDP). See #36121 + # and #36373 for more information. In a nutshell, it is `map(gather_map, zip(*caches))`, i.e. each item in the + # iterable contains the key and value states for a layer gathered across replicas by torch.distributed + # (shape=[global batch size, num_heads, seq_len, head_dim]). + # WARNING: `_distributed_cache_data` must be the first argument in `__init__`, otherwise we'll break + # compatibility. The name of the argument doesn't matter. 
+ if _distributed_cache_data is not None: + for key_states, value_states in _distributed_cache_data: + self.key_cache.append(key_states) + self.value_cache.append(value_states) + def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]: """ Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the diff --git a/tests/utils/test_cache_utils.py b/tests/utils/test_cache_utils.py index a8b8b1eff2..e16d30e549 100644 --- a/tests/utils/test_cache_utils.py +++ b/tests/utils/test_cache_utils.py @@ -20,12 +20,14 @@ from parameterized import parameterized from transformers import set_seed from transformers.testing_utils import ( + get_gpu_count, is_torch_available, require_gptq, require_non_xpu, require_read_token, require_torch, require_torch_gpu, + require_torch_multi_gpu, slow, torch_device, ) @@ -620,3 +622,35 @@ class CacheIntegrationTest(unittest.TestCase): 'You are a helpful assistant. What is the capital of France?\n\n\n## Response:Paris is the capital of France.\n\n\n\n\n\n## Query:\n\nIn a detailed analysis, compare the economic impacts of the introduction of the' ] # fmt: skip self.assertEqual(responses, EXPECTED_DECODED_TEXT) + + @require_torch_multi_gpu + def test_data_parallel_dynamic_cache(self): + """ + Tests that the dynamic cache works with nn.DataParallel. Under the hood, `DynamicCache` is rebuilt from + multiple `DynamicCache` in the gather step. 
+ """ + + model_repo = "hf-internal-testing/tiny-random-MistralForCausalLM" + model = AutoModelForCausalLM.from_pretrained(model_repo).to(torch_device) + tokenizer = AutoTokenizer.from_pretrained(model_repo) + + # w/o DP: batch_size = num_gpus + # w/ DP: batch_size = 1 (with num_gpus replicas) + num_gpus = get_gpu_count() + model_inputs = tokenizer(["foo bar"] * num_gpus, return_tensors="pt").to(model.device) + + # w/o DP + no_parallelism_cache = model(**model_inputs).past_key_values + self.assertIsInstance(no_parallelism_cache, DynamicCache) + + # w/ DP + model = torch.nn.DataParallel(model) + parallelism_cache = model(**model_inputs).past_key_values + self.assertIsInstance(parallelism_cache, DynamicCache) + + # Check that the caches are the same + for layer_idx in range(len(no_parallelism_cache)): + for kv_idx in range(2): # 0 = key, 1 = value + torch.testing.assert_close( + actual=parallelism_cache[layer_idx][kv_idx], expected=no_parallelism_cache[layer_idx][kv_idx] + )