Fix FA2 tests (#29909)

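* Fix the Flash Attention 2 (FA2) tests: load the reference model without `attn_implementation="flash_attention_2"` so the FA2 model is compared against a non-FA2 baseline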

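* Rename the FA2 inference tests to reflect that they check equivalence (e.g. `test_flash_attn_2_inference` -> `test_flash_attn_2_inference_equivalence`, `test_flash_attn_2_inference_padding_right` -> `test_flash_attn_2_inference_equivalence_right_padding`)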
This commit is contained in:
Yoach Lacombe
2024-04-01 08:51:00 +01:00
committed by GitHub
parent 3b8e2932ce
commit 569f6c7d43
9 changed files with 15 additions and 19 deletions

View File

@@ -3245,7 +3245,7 @@ class ModelTesterMixin:
@require_torch_gpu
@mark.flash_attn_test
@slow
def test_flash_attn_2_inference(self):
def test_flash_attn_2_inference_equivalence(self):
for model_class in self.all_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
@@ -3260,9 +3260,7 @@ class ModelTesterMixin:
)
model_fa.to(torch_device)
model = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
)
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
model.to(torch_device)
dummy_input = inputs_dict[model.main_input_name][:1]
@@ -3340,7 +3338,7 @@ class ModelTesterMixin:
@require_torch_gpu
@mark.flash_attn_test
@slow
def test_flash_attn_2_inference_padding_right(self):
def test_flash_attn_2_inference_equivalence_right_padding(self):
for model_class in self.all_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
@@ -3355,9 +3353,7 @@ class ModelTesterMixin:
)
model_fa.to(torch_device)
model = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
)
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
model.to(torch_device)
dummy_input = inputs_dict[model.main_input_name][:1]