Init cache on meta device (#35164)

* init cache on meta device

* offloaded static + enable tests

* tests weren't running before  :(

* update

* fix mamba

* fix copies

* update

* address comments and fix tests

* fix copies

* Update src/transformers/cache_utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* update

* mamba fix

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
Raushan Turganbay
2025-01-22 09:49:17 +01:00
committed by GitHub
parent 870e2c8ea0
commit 373e50e970
10 changed files with 111 additions and 111 deletions

View File

@@ -4635,6 +4635,11 @@ class ModelTesterMixin:
fa2_correctly_converted = True
break
fa2_correctly_converted = (
fa2_correctly_converted
if not model_class._supports_flex_attn
else fa2_model.config._attn_implementation == "flash_attention_2"
)
self.assertTrue(fa2_correctly_converted)
_ = fa2_model(input_ids=dummy_input, attention_mask=dummy_attention_mask)
@@ -4653,6 +4658,11 @@ class ModelTesterMixin:
fa2_correctly_converted = True
break
fa2_correctly_converted = (
fa2_correctly_converted
if not model_class._supports_flex_attn
else model_from_pretrained.config._attn_implementation == "flash_attention_2"
)
self.assertFalse(fa2_correctly_converted)
def _get_custom_4d_mask_test_data(self):