Init cache on meta device (#35164)

* init cache on meta device

* offloaded static + enable tests

* tests weren't running before  :(

* update

* fix mamba

* fix copies

* update

* address comments and fix tests

* fix copies

* Update src/transformers/cache_utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* update

* mamba fix

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
Raushan Turganbay
2025-01-22 09:49:17 +01:00
committed by GitHub
parent 870e2c8ea0
commit 373e50e970
10 changed files with 111 additions and 111 deletions

View File

@@ -198,6 +198,7 @@ class CacheTest(unittest.TestCase):
cache_config={
"batch_size": batch_size,
"max_cache_len": max_cache_len,
"device": device,
},
),
)
@@ -310,11 +311,12 @@ class CacheIntegrationTest(unittest.TestCase):
do_sample=False,
max_new_tokens=20,
num_return_sequences=2,
num_beams=2,
)
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
expected_text = [
"Hello I am doing a project on the 1918 flu pandemic and I am trying to find out how many",
"Hello I am doing a project on the 1918 flu pandemic and I am trying to find out how many",
"Hello I am doing a project for my school and I am trying to make a program that will allow me to input a",
"Hello I am doing a project for my school and I am trying to make a program that will allow me to use a",
]
self.assertListEqual(decoded, expected_text)
@@ -380,8 +382,6 @@ class CacheIntegrationTest(unittest.TestCase):
[
("eager", "static"),
("sdpa", "static"),
("eager", "offloaded-static"),
("sdpa", "offloaded-static"),
]
)
def test_static_cache_greedy_decoding_pad_left(self, attn_implementation, cache_implementation):
@@ -427,8 +427,6 @@ class CacheIntegrationTest(unittest.TestCase):
[
("eager", "static"),
("sdpa", "static"),
("eager", "offloaded-static"),
("sdpa", "offloaded-static"),
]
)
def test_static_cache_greedy_decoding_pad_right(self, attn_implementation, cache_implementation):
@@ -462,26 +460,6 @@ class CacheIntegrationTest(unittest.TestCase):
with self.subTest(f"{attn_implementation}, static, eager"):
self.assertListEqual(decoded, EXPECTED_GENERATION)
set_seed(0)
model._forward = model.forward
compiled_forward = torch.compile(model.forward)
def compiled(func, input_ids, **kwargs):
return func(input_ids, **kwargs)
def call(input_ids, **kwargs):
if input_ids.shape[-1] == 1:
return compiled(compiled_forward, input_ids, **kwargs)
return model._forward(input_ids, **kwargs)
model.forward = call
gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10)
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
with self.subTest(f"{attn_implementation}, static, compiled"):
self.assertListEqual(decoded, EXPECTED_GENERATION)
def test_dynamic_cache_extra_left_padding(self):
"""Tests that adding extra left-padding does not affect the generation with the dynamic cache"""
EXPECTED_GENERATION = [
@@ -519,7 +497,6 @@ class CacheIntegrationTest(unittest.TestCase):
@parameterized.expand(
[
"static",
"offloaded-static",
]
)
def test_static_cache_extra_left_padding(self, cache_implementation):