Add a static cache that offloads to the CPU or other device (#32161)

* Add a static cache that offloads to the CPU or other device

* Fix PR comments, add unit-tests
This commit is contained in:
Gerben van V
2024-08-29 11:51:09 +02:00
committed by GitHub
parent 92a75ff6b1
commit 5129671290
7 changed files with 350 additions and 19 deletions

View File

@@ -380,8 +380,15 @@ class CacheIntegrationTest(unittest.TestCase):
self.assertTrue(decoded[0].endswith(last_output))
@require_torch_gpu
@parameterized.expand(["eager", "sdpa"])
def test_static_cache_greedy_decoding_pad_left(self, attn_implementation):
@parameterized.expand(
[
("eager", "static"),
("sdpa", "static"),
("eager", "offloaded-static"),
("sdpa", "offloaded-static"),
]
)
def test_static_cache_greedy_decoding_pad_left(self, attn_implementation, cache_implementation):
EXPECTED_GENERATION = [
"The best color is the one that complements the skin tone of the",
"We should not undermind the issues at hand.\nWe should not undermind the issues",
@@ -406,7 +413,7 @@ class CacheIntegrationTest(unittest.TestCase):
self.assertListEqual(decoded, EXPECTED_GENERATION)
set_seed(0)
model.generation_config.cache_implementation = "static"
model.generation_config.cache_implementation = cache_implementation
gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10)
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
with self.subTest(f"{attn_implementation}, static, eager"):
@@ -420,8 +427,15 @@ class CacheIntegrationTest(unittest.TestCase):
self.assertListEqual(decoded, EXPECTED_GENERATION)
@require_torch_gpu
@parameterized.expand(["eager", "sdpa"])
def test_static_cache_greedy_decoding_pad_right(self, attn_implementation):
@parameterized.expand(
[
("eager", "static"),
("sdpa", "static"),
("eager", "offloaded-static"),
("sdpa", "offloaded-static"),
]
)
def test_static_cache_greedy_decoding_pad_right(self, attn_implementation, cache_implementation):
EXPECTED_GENERATION = [
"The best color isЋ the one that complements the skin tone of",
"We should not undermind the issues at hand.\nWe should not undermind the issues",
@@ -446,7 +460,7 @@ class CacheIntegrationTest(unittest.TestCase):
self.assertListEqual(decoded, EXPECTED_GENERATION)
set_seed(0)
model.generation_config.cache_implementation = "static"
model.generation_config.cache_implementation = cache_implementation
gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10)
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
with self.subTest(f"{attn_implementation}, static, eager"):
@@ -506,7 +520,13 @@ class CacheIntegrationTest(unittest.TestCase):
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
self.assertListEqual(decoded, EXPECTED_GENERATION)
def test_static_cache_extra_left_padding(self):
@parameterized.expand(
[
"static",
"offloaded-static",
]
)
def test_static_cache_extra_left_padding(self, cache_implementation):
"""Tests that adding extra left-padding does not affect the generation with the static cache"""
EXPECTED_GENERATION = [
"The best color is the one that complements the skin tone of the",
@@ -524,7 +544,7 @@ class CacheIntegrationTest(unittest.TestCase):
["The best color is", "We should not undermind the issues at hand"], padding=True, return_tensors="pt"
).to(model.device)
model.generation_config.cache_implementation = "static"
model.generation_config.cache_implementation = cache_implementation
gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10)
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)