Init cache on meta device (#35164)

* init cache on meta device

* offloaded static + enable tests

* tests weren't running before  :(

* update

* fix mamba

* fix copies

* update

* address comments and fix tests

* fix copies

* Update src/transformers/cache_utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* update

* mamba fix

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
Raushan Turganbay
2025-01-22 09:49:17 +01:00
committed by GitHub
parent 870e2c8ea0
commit 373e50e970
10 changed files with 111 additions and 111 deletions

View File

@@ -728,22 +728,13 @@ class LlamaIntegrationTest(unittest.TestCase):
dynamic_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION, dynamic_text)
# Static Cache
# Static Cache + compile (`generate()` internally compiles each decoding step when static cache is used)
generated_ids = model.generate(
**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static"
)
static_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION, static_text)
# Static Cache + compile
model._cache = None # clear cache object, initialized when we pass `cache_implementation="static"`
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
generated_ids = model.generate(
**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static"
)
static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION, static_compiled_text)
@slow
@require_read_token
def test_export_static_cache(self):
@@ -795,6 +786,7 @@ class LlamaIntegrationTest(unittest.TestCase):
cache_config={
"batch_size": batch_size,
"max_cache_len": max_generation_length,
"device": device,
},
),
)