Init cache on meta device (#35164)
* init cache on meta device * offloaded static + enable tests * tests weren't running before :( * update * fix mamba * fix copies * update * address comments and fix tests * fix copies * Update src/transformers/cache_utils.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * update * mamba fix --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
870e2c8ea0
commit
373e50e970
@@ -4635,6 +4635,11 @@ class ModelTesterMixin:
|
||||
fa2_correctly_converted = True
|
||||
break
|
||||
|
||||
fa2_correctly_converted = (
|
||||
fa2_correctly_converted
|
||||
if not model_class._supports_flex_attn
|
||||
else fa2_model.config._attn_implementation == "flash_attention_2"
|
||||
)
|
||||
self.assertTrue(fa2_correctly_converted)
|
||||
|
||||
_ = fa2_model(input_ids=dummy_input, attention_mask=dummy_attention_mask)
|
||||
@@ -4653,6 +4658,11 @@ class ModelTesterMixin:
|
||||
fa2_correctly_converted = True
|
||||
break
|
||||
|
||||
fa2_correctly_converted = (
|
||||
fa2_correctly_converted
|
||||
if not model_class._supports_flex_attn
|
||||
else model_from_pretrained.config._attn_implementation == "flash_attention_2"
|
||||
)
|
||||
self.assertFalse(fa2_correctly_converted)
|
||||
|
||||
def _get_custom_4d_mask_test_data(self):
|
||||
|
||||
Reference in New Issue
Block a user