Init cache on meta device (#35164)
* init cache on meta device * offloaded static + enable tests * tests weren't running before :( * update * fix mamba * fix copies * update * address comments and fix tests * fix copies * Update src/transformers/cache_utils.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * update * mamba fix --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
870e2c8ea0
commit
373e50e970
@@ -728,22 +728,13 @@ class LlamaIntegrationTest(unittest.TestCase):
|
||||
dynamic_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
|
||||
self.assertEqual(EXPECTED_TEXT_COMPLETION, dynamic_text)
|
||||
|
||||
# Static Cache
|
||||
# Static Cache + compile (`generate()` internally compiles each decoding step when static cache is used)
|
||||
generated_ids = model.generate(
|
||||
**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static"
|
||||
)
|
||||
static_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
|
||||
self.assertEqual(EXPECTED_TEXT_COMPLETION, static_text)
|
||||
|
||||
# Static Cache + compile
|
||||
model._cache = None # clear cache object, initialized when we pass `cache_implementation="static"`
|
||||
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
|
||||
generated_ids = model.generate(
|
||||
**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static"
|
||||
)
|
||||
static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
|
||||
self.assertEqual(EXPECTED_TEXT_COMPLETION, static_compiled_text)
|
||||
|
||||
@slow
|
||||
@require_read_token
|
||||
def test_export_static_cache(self):
|
||||
@@ -795,6 +786,7 @@ class LlamaIntegrationTest(unittest.TestCase):
|
||||
cache_config={
|
||||
"batch_size": batch_size,
|
||||
"max_cache_len": max_generation_length,
|
||||
"device": device,
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user