[ tests] remove some flash attention class tests (#35817)
remove class from tests
This commit is contained in:
@@ -4628,43 +4628,12 @@ class ModelTesterMixin:
|
|||||||
dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
|
dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
|
||||||
dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [0, 1, 1, 1]]).to(torch_device)
|
dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [0, 1, 1, 1]]).to(torch_device)
|
||||||
|
|
||||||
fa2_correctly_converted = False
|
|
||||||
|
|
||||||
for _, module in fa2_model.named_modules():
|
|
||||||
if "FlashAttention" in module.__class__.__name__:
|
|
||||||
fa2_correctly_converted = True
|
|
||||||
break
|
|
||||||
|
|
||||||
fa2_correctly_converted = (
|
|
||||||
fa2_correctly_converted
|
|
||||||
if not model_class._supports_flex_attn
|
|
||||||
else fa2_model.config._attn_implementation == "flash_attention_2"
|
|
||||||
)
|
|
||||||
self.assertTrue(fa2_correctly_converted)
|
|
||||||
|
|
||||||
_ = fa2_model(input_ids=dummy_input, attention_mask=dummy_attention_mask)
|
_ = fa2_model(input_ids=dummy_input, attention_mask=dummy_attention_mask)
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
fa2_model.save_pretrained(tmpdirname)
|
fa2_model.save_pretrained(tmpdirname)
|
||||||
|
|
||||||
model_from_pretrained = model_class.from_pretrained(tmpdirname)
|
model_from_pretrained = model_class.from_pretrained(tmpdirname)
|
||||||
|
|
||||||
self.assertTrue(model_from_pretrained.config._attn_implementation != "flash_attention_2")
|
self.assertTrue(model_from_pretrained.config._attn_implementation != "flash_attention_2")
|
||||||
|
|
||||||
fa2_correctly_converted = False
|
|
||||||
|
|
||||||
for _, module in model_from_pretrained.named_modules():
|
|
||||||
if "FlashAttention" in module.__class__.__name__:
|
|
||||||
fa2_correctly_converted = True
|
|
||||||
break
|
|
||||||
|
|
||||||
fa2_correctly_converted = (
|
|
||||||
fa2_correctly_converted
|
|
||||||
if not model_class._supports_flex_attn
|
|
||||||
else model_from_pretrained.config._attn_implementation == "flash_attention_2"
|
|
||||||
)
|
|
||||||
self.assertFalse(fa2_correctly_converted)
|
|
||||||
|
|
||||||
def _get_custom_4d_mask_test_data(self):
|
def _get_custom_4d_mask_test_data(self):
|
||||||
# Sequence in which all but the last token is the same
|
# Sequence in which all but the last token is the same
|
||||||
input_ids = torch.tensor(
|
input_ids = torch.tensor(
|
||||||
|
|||||||
Reference in New Issue
Block a user