[ tests] remove some flash attention class tests (#35817)

remove class from tests
This commit is contained in:
Arthur
2025-01-23 09:44:21 +01:00
committed by GitHub
parent 2c3a44f9a7
commit 8736e91ad6

View File

@@ -4628,43 +4628,12 @@ class ModelTesterMixin:
dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [0, 1, 1, 1]]).to(torch_device)
fa2_correctly_converted = False
for _, module in fa2_model.named_modules():
if "FlashAttention" in module.__class__.__name__:
fa2_correctly_converted = True
break
fa2_correctly_converted = (
fa2_correctly_converted
if not model_class._supports_flex_attn
else fa2_model.config._attn_implementation == "flash_attention_2"
)
self.assertTrue(fa2_correctly_converted)
_ = fa2_model(input_ids=dummy_input, attention_mask=dummy_attention_mask)
with tempfile.TemporaryDirectory() as tmpdirname:
fa2_model.save_pretrained(tmpdirname)
model_from_pretrained = model_class.from_pretrained(tmpdirname)
self.assertTrue(model_from_pretrained.config._attn_implementation != "flash_attention_2")
fa2_correctly_converted = False
for _, module in model_from_pretrained.named_modules():
if "FlashAttention" in module.__class__.__name__:
fa2_correctly_converted = True
break
fa2_correctly_converted = (
fa2_correctly_converted
if not model_class._supports_flex_attn
else model_from_pretrained.config._attn_implementation == "flash_attention_2"
)
self.assertFalse(fa2_correctly_converted)
def _get_custom_4d_mask_test_data(self):
# Sequence in which all but the last token is the same
input_ids = torch.tensor(