[tests] Remove some flash attention class tests (#35817)
Remove the attention-module class-name checks from the flash attention tests.
This commit is contained in:
@@ -4628,43 +4628,12 @@ class ModelTesterMixin:
|
||||
dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
|
||||
dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [0, 1, 1, 1]]).to(torch_device)
|
||||
|
||||
fa2_correctly_converted = False
|
||||
|
||||
for _, module in fa2_model.named_modules():
|
||||
if "FlashAttention" in module.__class__.__name__:
|
||||
fa2_correctly_converted = True
|
||||
break
|
||||
|
||||
fa2_correctly_converted = (
|
||||
fa2_correctly_converted
|
||||
if not model_class._supports_flex_attn
|
||||
else fa2_model.config._attn_implementation == "flash_attention_2"
|
||||
)
|
||||
self.assertTrue(fa2_correctly_converted)
|
||||
|
||||
_ = fa2_model(input_ids=dummy_input, attention_mask=dummy_attention_mask)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
fa2_model.save_pretrained(tmpdirname)
|
||||
|
||||
model_from_pretrained = model_class.from_pretrained(tmpdirname)
|
||||
|
||||
self.assertTrue(model_from_pretrained.config._attn_implementation != "flash_attention_2")
|
||||
|
||||
fa2_correctly_converted = False
|
||||
|
||||
for _, module in model_from_pretrained.named_modules():
|
||||
if "FlashAttention" in module.__class__.__name__:
|
||||
fa2_correctly_converted = True
|
||||
break
|
||||
|
||||
fa2_correctly_converted = (
|
||||
fa2_correctly_converted
|
||||
if not model_class._supports_flex_attn
|
||||
else model_from_pretrained.config._attn_implementation == "flash_attention_2"
|
||||
)
|
||||
self.assertFalse(fa2_correctly_converted)
|
||||
|
||||
def _get_custom_4d_mask_test_data(self):
|
||||
# Sequence in which all but the last token is the same
|
||||
input_ids = torch.tensor(
|
||||
|
||||
Reference in New Issue
Block a user