Init cache on meta device (#35164)
* init cache on meta device * offloaded static + enable tests * tests weren't running before :( * update * fix mamba * fix copies * update * address comments and fix tests * fix copies * Update src/transformers/cache_utils.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * update * mamba fix --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
870e2c8ea0
commit
373e50e970
@@ -728,22 +728,13 @@ class LlamaIntegrationTest(unittest.TestCase):
|
||||
dynamic_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
|
||||
self.assertEqual(EXPECTED_TEXT_COMPLETION, dynamic_text)
|
||||
|
||||
# Static Cache
|
||||
# Static Cache + compile (`generate()` internally compiles each decoding step when static cache is used)
|
||||
generated_ids = model.generate(
|
||||
**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static"
|
||||
)
|
||||
static_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
|
||||
self.assertEqual(EXPECTED_TEXT_COMPLETION, static_text)
|
||||
|
||||
# Static Cache + compile
|
||||
model._cache = None # clear cache object, initialized when we pass `cache_implementation="static"`
|
||||
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
|
||||
generated_ids = model.generate(
|
||||
**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static"
|
||||
)
|
||||
static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
|
||||
self.assertEqual(EXPECTED_TEXT_COMPLETION, static_compiled_text)
|
||||
|
||||
@slow
|
||||
@require_read_token
|
||||
def test_export_static_cache(self):
|
||||
@@ -795,6 +786,7 @@ class LlamaIntegrationTest(unittest.TestCase):
|
||||
cache_config={
|
||||
"batch_size": batch_size,
|
||||
"max_cache_len": max_generation_length,
|
||||
"device": device,
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
@@ -4635,6 +4635,11 @@ class ModelTesterMixin:
|
||||
fa2_correctly_converted = True
|
||||
break
|
||||
|
||||
fa2_correctly_converted = (
|
||||
fa2_correctly_converted
|
||||
if not model_class._supports_flex_attn
|
||||
else fa2_model.config._attn_implementation == "flash_attention_2"
|
||||
)
|
||||
self.assertTrue(fa2_correctly_converted)
|
||||
|
||||
_ = fa2_model(input_ids=dummy_input, attention_mask=dummy_attention_mask)
|
||||
@@ -4653,6 +4658,11 @@ class ModelTesterMixin:
|
||||
fa2_correctly_converted = True
|
||||
break
|
||||
|
||||
fa2_correctly_converted = (
|
||||
fa2_correctly_converted
|
||||
if not model_class._supports_flex_attn
|
||||
else model_from_pretrained.config._attn_implementation == "flash_attention_2"
|
||||
)
|
||||
self.assertFalse(fa2_correctly_converted)
|
||||
|
||||
def _get_custom_4d_mask_test_data(self):
|
||||
|
||||
@@ -198,6 +198,7 @@ class CacheTest(unittest.TestCase):
|
||||
cache_config={
|
||||
"batch_size": batch_size,
|
||||
"max_cache_len": max_cache_len,
|
||||
"device": device,
|
||||
},
|
||||
),
|
||||
)
|
||||
@@ -310,11 +311,12 @@ class CacheIntegrationTest(unittest.TestCase):
|
||||
do_sample=False,
|
||||
max_new_tokens=20,
|
||||
num_return_sequences=2,
|
||||
num_beams=2,
|
||||
)
|
||||
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
|
||||
expected_text = [
|
||||
"Hello I am doing a project on the 1918 flu pandemic and I am trying to find out how many",
|
||||
"Hello I am doing a project on the 1918 flu pandemic and I am trying to find out how many",
|
||||
"Hello I am doing a project for my school and I am trying to make a program that will allow me to input a",
|
||||
"Hello I am doing a project for my school and I am trying to make a program that will allow me to use a",
|
||||
]
|
||||
self.assertListEqual(decoded, expected_text)
|
||||
|
||||
@@ -380,8 +382,6 @@ class CacheIntegrationTest(unittest.TestCase):
|
||||
[
|
||||
("eager", "static"),
|
||||
("sdpa", "static"),
|
||||
("eager", "offloaded-static"),
|
||||
("sdpa", "offloaded-static"),
|
||||
]
|
||||
)
|
||||
def test_static_cache_greedy_decoding_pad_left(self, attn_implementation, cache_implementation):
|
||||
@@ -427,8 +427,6 @@ class CacheIntegrationTest(unittest.TestCase):
|
||||
[
|
||||
("eager", "static"),
|
||||
("sdpa", "static"),
|
||||
("eager", "offloaded-static"),
|
||||
("sdpa", "offloaded-static"),
|
||||
]
|
||||
)
|
||||
def test_static_cache_greedy_decoding_pad_right(self, attn_implementation, cache_implementation):
|
||||
@@ -462,26 +460,6 @@ class CacheIntegrationTest(unittest.TestCase):
|
||||
with self.subTest(f"{attn_implementation}, static, eager"):
|
||||
self.assertListEqual(decoded, EXPECTED_GENERATION)
|
||||
|
||||
set_seed(0)
|
||||
model._forward = model.forward
|
||||
compiled_forward = torch.compile(model.forward)
|
||||
|
||||
def compiled(func, input_ids, **kwargs):
|
||||
return func(input_ids, **kwargs)
|
||||
|
||||
def call(input_ids, **kwargs):
|
||||
if input_ids.shape[-1] == 1:
|
||||
return compiled(compiled_forward, input_ids, **kwargs)
|
||||
|
||||
return model._forward(input_ids, **kwargs)
|
||||
|
||||
model.forward = call
|
||||
|
||||
gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10)
|
||||
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
|
||||
with self.subTest(f"{attn_implementation}, static, compiled"):
|
||||
self.assertListEqual(decoded, EXPECTED_GENERATION)
|
||||
|
||||
def test_dynamic_cache_extra_left_padding(self):
|
||||
"""Tests that adding extra left-padding does not affect the generation with the dynamic cache"""
|
||||
EXPECTED_GENERATION = [
|
||||
@@ -519,7 +497,6 @@ class CacheIntegrationTest(unittest.TestCase):
|
||||
@parameterized.expand(
|
||||
[
|
||||
"static",
|
||||
"offloaded-static",
|
||||
]
|
||||
)
|
||||
def test_static_cache_extra_left_padding(self, cache_implementation):
|
||||
|
||||
Reference in New Issue
Block a user