[generate] torch.distributed-compatible DynamicCache (#36373)
* test * docstring * prepare distributed cache data * fix cat dim * test mvp * add test checks * like this? * working test and solution * nit * nit * add shape info
This commit is contained in:
@@ -20,12 +20,14 @@ from parameterized import parameterized
|
||||
|
||||
from transformers import set_seed
|
||||
from transformers.testing_utils import (
|
||||
get_gpu_count,
|
||||
is_torch_available,
|
||||
require_gptq,
|
||||
require_non_xpu,
|
||||
require_read_token,
|
||||
require_torch,
|
||||
require_torch_gpu,
|
||||
require_torch_multi_gpu,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
@@ -620,3 +622,35 @@ class CacheIntegrationTest(unittest.TestCase):
|
||||
'You are a helpful assistant. What is the capital of France?\n\n\n## Response:Paris is the capital of France.\n\n\n\n\n\n## Query:\n\nIn a detailed analysis, compare the economic impacts of the introduction of the'
|
||||
] # fmt: skip
|
||||
self.assertEqual(responses, EXPECTED_DECODED_TEXT)
|
||||
|
||||
@require_torch_multi_gpu
|
||||
def test_data_parallel_dynamic_cache(self):
|
||||
"""
|
||||
Tests that the dynamic cache works with nn.DataParallel. Under the hood, `DynamicCache` is rebuilt from
|
||||
multiple `DynamicCache` in the gather step.
|
||||
"""
|
||||
|
||||
model_repo = "hf-internal-testing/tiny-random-MistralForCausalLM"
|
||||
model = AutoModelForCausalLM.from_pretrained(model_repo).to(torch_device)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_repo)
|
||||
|
||||
# w/o DP: batch_size = num_gpu
|
||||
# w DP: batch_size = 1 (with num_gpus replicas)
|
||||
num_gpus = get_gpu_count()
|
||||
model_inputs = tokenizer(["foo bar"] * num_gpus, return_tensors="pt").to(model.device)
|
||||
|
||||
# w/o DP
|
||||
no_parallelism_cache = model(**model_inputs).past_key_values
|
||||
self.assertIsInstance(no_parallelism_cache, DynamicCache)
|
||||
|
||||
# w DP
|
||||
model = torch.nn.DataParallel(model)
|
||||
parallelism_cache = model(**model_inputs).past_key_values
|
||||
self.assertIsInstance(parallelism_cache, DynamicCache)
|
||||
|
||||
# Check that the caches are the same
|
||||
for layer_idx in range(len(no_parallelism_cache)):
|
||||
for kv_idx in range(2): # 0 = key, 1 = value
|
||||
torch.testing.assert_close(
|
||||
actual=parallelism_cache[layer_idx][kv_idx], expected=no_parallelism_cache[layer_idx][kv_idx]
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user