@@ -879,7 +879,7 @@ class BarkFineModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
@require_torch_gpu
|
@require_torch_gpu
|
||||||
@pytest.mark.flash_attn_test
|
@pytest.mark.flash_attn_test
|
||||||
@slow
|
@slow
|
||||||
def test_flash_attn_2_inference(self):
|
def test_flash_attn_2_inference_equivalence(self):
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
if not model_class._supports_flash_attn_2:
|
if not model_class._supports_flash_attn_2:
|
||||||
return
|
return
|
||||||
@@ -936,7 +936,7 @@ class BarkFineModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
@require_torch_gpu
|
@require_torch_gpu
|
||||||
@pytest.mark.flash_attn_test
|
@pytest.mark.flash_attn_test
|
||||||
@slow
|
@slow
|
||||||
def test_flash_attn_2_inference_padding_right(self):
|
def test_flash_attn_2_inference_equivalence_right_padding(self):
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
if not model_class._supports_flash_attn_2:
|
if not model_class._supports_flash_attn_2:
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -301,7 +301,7 @@ class DistilBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
|
|||||||
@require_torch_accelerator
|
@require_torch_accelerator
|
||||||
@pytest.mark.flash_attn_test
|
@pytest.mark.flash_attn_test
|
||||||
@slow
|
@slow
|
||||||
def test_flash_attn_2_inference(self):
|
def test_flash_attn_2_inference_equivalence(self):
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
@@ -353,7 +353,7 @@ class DistilBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
|
|||||||
@require_torch_accelerator
|
@require_torch_accelerator
|
||||||
@pytest.mark.flash_attn_test
|
@pytest.mark.flash_attn_test
|
||||||
@slow
|
@slow
|
||||||
def test_flash_attn_2_inference_padding_right(self):
|
def test_flash_attn_2_inference_equivalence_right_padding(self):
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
|
|||||||
@@ -462,7 +462,7 @@ class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
|||||||
@require_torch_gpu
|
@require_torch_gpu
|
||||||
@pytest.mark.flash_attn_test
|
@pytest.mark.flash_attn_test
|
||||||
@slow
|
@slow
|
||||||
def test_flash_attn_2_inference_padding_right(self):
|
def test_flash_attn_2_inference_equivalence_right_padding(self):
|
||||||
self.skipTest("Gemma flash attention does not support right padding")
|
self.skipTest("Gemma flash attention does not support right padding")
|
||||||
|
|
||||||
@require_torch_sdpa
|
@require_torch_sdpa
|
||||||
|
|||||||
@@ -466,7 +466,7 @@ class MistralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
|||||||
@require_torch_gpu
|
@require_torch_gpu
|
||||||
@pytest.mark.flash_attn_test
|
@pytest.mark.flash_attn_test
|
||||||
@slow
|
@slow
|
||||||
def test_flash_attn_2_inference_padding_right(self):
|
def test_flash_attn_2_inference_equivalence_right_padding(self):
|
||||||
self.skipTest("Mistral flash attention does not support right padding")
|
self.skipTest("Mistral flash attention does not support right padding")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -465,7 +465,7 @@ class MixtralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
|||||||
@require_torch_gpu
|
@require_torch_gpu
|
||||||
@pytest.mark.flash_attn_test
|
@pytest.mark.flash_attn_test
|
||||||
@slow
|
@slow
|
||||||
def test_flash_attn_2_inference_padding_right(self):
|
def test_flash_attn_2_inference_equivalence_right_padding(self):
|
||||||
self.skipTest("Mixtral flash attention does not support right padding")
|
self.skipTest("Mixtral flash attention does not support right padding")
|
||||||
|
|
||||||
# Ignore copy
|
# Ignore copy
|
||||||
|
|||||||
@@ -477,7 +477,7 @@ class Qwen2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
|||||||
@require_torch_gpu
|
@require_torch_gpu
|
||||||
@pytest.mark.flash_attn_test
|
@pytest.mark.flash_attn_test
|
||||||
@slow
|
@slow
|
||||||
def test_flash_attn_2_inference_padding_right(self):
|
def test_flash_attn_2_inference_equivalence_right_padding(self):
|
||||||
self.skipTest("Qwen2 flash attention does not support right padding")
|
self.skipTest("Qwen2 flash attention does not support right padding")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -461,7 +461,7 @@ class Starcoder2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
|
|||||||
@require_torch_gpu
|
@require_torch_gpu
|
||||||
@pytest.mark.flash_attn_test
|
@pytest.mark.flash_attn_test
|
||||||
@slow
|
@slow
|
||||||
def test_flash_attn_2_inference_padding_right(self):
|
def test_flash_attn_2_inference_equivalence_right_padding(self):
|
||||||
self.skipTest("Starcoder2 flash attention does not support right padding")
|
self.skipTest("Starcoder2 flash attention does not support right padding")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -888,7 +888,7 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
|||||||
@require_torch_gpu
|
@require_torch_gpu
|
||||||
@pytest.mark.flash_attn_test
|
@pytest.mark.flash_attn_test
|
||||||
@slow
|
@slow
|
||||||
def test_flash_attn_2_inference(self):
|
def test_flash_attn_2_inference_equivalence(self):
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
@@ -934,7 +934,7 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
|||||||
@require_torch_gpu
|
@require_torch_gpu
|
||||||
@pytest.mark.flash_attn_test
|
@pytest.mark.flash_attn_test
|
||||||
@slow
|
@slow
|
||||||
def test_flash_attn_2_inference_padding_right(self):
|
def test_flash_attn_2_inference_equivalence_right_padding(self):
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
|
|||||||
@@ -3245,7 +3245,7 @@ class ModelTesterMixin:
|
|||||||
@require_torch_gpu
|
@require_torch_gpu
|
||||||
@mark.flash_attn_test
|
@mark.flash_attn_test
|
||||||
@slow
|
@slow
|
||||||
def test_flash_attn_2_inference(self):
|
def test_flash_attn_2_inference_equivalence(self):
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
if not model_class._supports_flash_attn_2:
|
if not model_class._supports_flash_attn_2:
|
||||||
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
|
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
|
||||||
@@ -3260,9 +3260,7 @@ class ModelTesterMixin:
|
|||||||
)
|
)
|
||||||
model_fa.to(torch_device)
|
model_fa.to(torch_device)
|
||||||
|
|
||||||
model = model_class.from_pretrained(
|
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
|
||||||
tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
|
|
||||||
)
|
|
||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
|
|
||||||
dummy_input = inputs_dict[model.main_input_name][:1]
|
dummy_input = inputs_dict[model.main_input_name][:1]
|
||||||
@@ -3340,7 +3338,7 @@ class ModelTesterMixin:
|
|||||||
@require_torch_gpu
|
@require_torch_gpu
|
||||||
@mark.flash_attn_test
|
@mark.flash_attn_test
|
||||||
@slow
|
@slow
|
||||||
def test_flash_attn_2_inference_padding_right(self):
|
def test_flash_attn_2_inference_equivalence_right_padding(self):
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
if not model_class._supports_flash_attn_2:
|
if not model_class._supports_flash_attn_2:
|
||||||
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
|
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
|
||||||
@@ -3355,9 +3353,7 @@ class ModelTesterMixin:
|
|||||||
)
|
)
|
||||||
model_fa.to(torch_device)
|
model_fa.to(torch_device)
|
||||||
|
|
||||||
model = model_class.from_pretrained(
|
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
|
||||||
tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
|
|
||||||
)
|
|
||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
|
|
||||||
dummy_input = inputs_dict[model.main_input_name][:1]
|
dummy_input = inputs_dict[model.main_input_name][:1]
|
||||||
|
|||||||
Reference in New Issue
Block a user