From 8736e91ad602c5dc436b5403bb8c5249fbb13941 Mon Sep 17 00:00:00 2001 From: Arthur <48595927+ArthurZucker@users.noreply.github.com> Date: Thu, 23 Jan 2025 09:44:21 +0100 Subject: [PATCH] [ `tests`] remove some flash attention class tests (#35817) remove class from tests --- tests/test_modeling_common.py | 31 ------------------------------- 1 file changed, 31 deletions(-) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index d361378503..bde6e07eff 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -4628,43 +4628,12 @@ class ModelTesterMixin: dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device) dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [0, 1, 1, 1]]).to(torch_device) - fa2_correctly_converted = False - - for _, module in fa2_model.named_modules(): - if "FlashAttention" in module.__class__.__name__: - fa2_correctly_converted = True - break - - fa2_correctly_converted = ( - fa2_correctly_converted - if not model_class._supports_flex_attn - else fa2_model.config._attn_implementation == "flash_attention_2" - ) - self.assertTrue(fa2_correctly_converted) - _ = fa2_model(input_ids=dummy_input, attention_mask=dummy_attention_mask) - with tempfile.TemporaryDirectory() as tmpdirname: fa2_model.save_pretrained(tmpdirname) - model_from_pretrained = model_class.from_pretrained(tmpdirname) - self.assertTrue(model_from_pretrained.config._attn_implementation != "flash_attention_2") - fa2_correctly_converted = False - - for _, module in model_from_pretrained.named_modules(): - if "FlashAttention" in module.__class__.__name__: - fa2_correctly_converted = True - break - - fa2_correctly_converted = ( - fa2_correctly_converted - if not model_class._supports_flex_attn - else model_from_pretrained.config._attn_implementation == "flash_attention_2" - ) - self.assertFalse(fa2_correctly_converted) - def _get_custom_4d_mask_test_data(self): # Sequence in which all but the last token is the same input_ids = torch.tensor(