Add a static cache that offloads to the CPU or other device (#32161)
* Add a static cache that offloads to the CPU or other device * Fix PR comments, add unit-tests
This commit is contained in:
@@ -380,8 +380,15 @@ class CacheIntegrationTest(unittest.TestCase):
|
||||
self.assertTrue(decoded[0].endswith(last_output))
|
||||
|
||||
@require_torch_gpu
|
||||
@parameterized.expand(["eager", "sdpa"])
|
||||
def test_static_cache_greedy_decoding_pad_left(self, attn_implementation):
|
||||
@parameterized.expand(
|
||||
[
|
||||
("eager", "static"),
|
||||
("sdpa", "static"),
|
||||
("eager", "offloaded-static"),
|
||||
("sdpa", "offloaded-static"),
|
||||
]
|
||||
)
|
||||
def test_static_cache_greedy_decoding_pad_left(self, attn_implementation, cache_implementation):
|
||||
EXPECTED_GENERATION = [
|
||||
"The best color is the one that complements the skin tone of the",
|
||||
"We should not undermind the issues at hand.\nWe should not undermind the issues",
|
||||
@@ -406,7 +413,7 @@ class CacheIntegrationTest(unittest.TestCase):
|
||||
self.assertListEqual(decoded, EXPECTED_GENERATION)
|
||||
|
||||
set_seed(0)
|
||||
model.generation_config.cache_implementation = "static"
|
||||
model.generation_config.cache_implementation = cache_implementation
|
||||
gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10)
|
||||
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
|
||||
with self.subTest(f"{attn_implementation}, static, eager"):
|
||||
@@ -420,8 +427,15 @@ class CacheIntegrationTest(unittest.TestCase):
|
||||
self.assertListEqual(decoded, EXPECTED_GENERATION)
|
||||
|
||||
@require_torch_gpu
|
||||
@parameterized.expand(["eager", "sdpa"])
|
||||
def test_static_cache_greedy_decoding_pad_right(self, attn_implementation):
|
||||
@parameterized.expand(
|
||||
[
|
||||
("eager", "static"),
|
||||
("sdpa", "static"),
|
||||
("eager", "offloaded-static"),
|
||||
("sdpa", "offloaded-static"),
|
||||
]
|
||||
)
|
||||
def test_static_cache_greedy_decoding_pad_right(self, attn_implementation, cache_implementation):
|
||||
EXPECTED_GENERATION = [
|
||||
"The best color isЋ the one that complements the skin tone of",
|
||||
"We should not undermind the issues at hand.\nWe should not undermind the issues",
|
||||
@@ -446,7 +460,7 @@ class CacheIntegrationTest(unittest.TestCase):
|
||||
self.assertListEqual(decoded, EXPECTED_GENERATION)
|
||||
|
||||
set_seed(0)
|
||||
model.generation_config.cache_implementation = "static"
|
||||
model.generation_config.cache_implementation = cache_implementation
|
||||
gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10)
|
||||
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
|
||||
with self.subTest(f"{attn_implementation}, static, eager"):
|
||||
@@ -506,7 +520,13 @@ class CacheIntegrationTest(unittest.TestCase):
|
||||
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
|
||||
self.assertListEqual(decoded, EXPECTED_GENERATION)
|
||||
|
||||
def test_static_cache_extra_left_padding(self):
|
||||
@parameterized.expand(
|
||||
[
|
||||
"static",
|
||||
"offloaded-static",
|
||||
]
|
||||
)
|
||||
def test_static_cache_extra_left_padding(self, cache_implementation):
|
||||
"""Tests that adding extra left-padding does not affect the generation with the static cache"""
|
||||
EXPECTED_GENERATION = [
|
||||
"The best color is the one that complements the skin tone of the",
|
||||
@@ -524,7 +544,7 @@ class CacheIntegrationTest(unittest.TestCase):
|
||||
["The best color is", "We should not undermind the issues at hand"], padding=True, return_tensors="pt"
|
||||
).to(model.device)
|
||||
|
||||
model.generation_config.cache_implementation = "static"
|
||||
model.generation_config.cache_implementation = cache_implementation
|
||||
|
||||
gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10)
|
||||
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
|
||||
|
||||
Reference in New Issue
Block a user