add the missing flash attention test marker (#32419)

* add flash attention check

* fix

* fix

* add the missing marker

* bug fix

* add one more

* remove order

* add one more
This commit is contained in:
Fanli Lin
2024-08-06 18:18:58 +08:00
committed by GitHub
parent 0aa8328293
commit e85d86398a
7 changed files with 9 additions and 2 deletions

View File

@@ -628,9 +628,9 @@ class GemmaIntegrationTest(unittest.TestCase):
self.assertEqual(output_text, EXPECTED_TEXTS) self.assertEqual(output_text, EXPECTED_TEXTS)
@pytest.mark.flash_attn_test
@require_flash_attn @require_flash_attn
@require_read_token @require_read_token
@pytest.mark.flash_attn_test
def test_model_2b_flash_attn(self): def test_model_2b_flash_attn(self):
model_id = "google/gemma-2b" model_id = "google/gemma-2b"
EXPECTED_TEXTS = [ EXPECTED_TEXTS = [

View File

@@ -620,6 +620,7 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
@require_flash_attn @require_flash_attn
@require_torch_gpu @require_torch_gpu
@slow @slow
@pytest.mark.flash_attn_test
def test_use_flash_attention_2_true(self): def test_use_flash_attention_2_true(self):
""" """
NOTE: this is the only test testing that the legacy `use_flash_attention=2` argument still works as intended. NOTE: this is the only test testing that the legacy `use_flash_attention=2` argument still works as intended.

View File

@@ -576,9 +576,10 @@ class MistralIntegrationTest(unittest.TestCase):
backend_empty_cache(torch_device) backend_empty_cache(torch_device)
gc.collect() gc.collect()
@require_flash_attn
@require_bitsandbytes @require_bitsandbytes
@slow @slow
@require_flash_attn @pytest.mark.flash_attn_test
def test_model_7b_long_prompt(self): def test_model_7b_long_prompt(self):
EXPECTED_OUTPUT_TOKEN_IDS = [306, 338] EXPECTED_OUTPUT_TOKEN_IDS = [306, 338]
# An input with 4097 tokens that is above the size of the sliding window # An input with 4097 tokens that is above the size of the sliding window

View File

@@ -544,6 +544,7 @@ class Qwen2IntegrationTest(unittest.TestCase):
@require_bitsandbytes @require_bitsandbytes
@slow @slow
@require_flash_attn @require_flash_attn
@pytest.mark.flash_attn_test
def test_model_450m_long_prompt(self): def test_model_450m_long_prompt(self):
EXPECTED_OUTPUT_TOKEN_IDS = [306, 338] EXPECTED_OUTPUT_TOKEN_IDS = [306, 338]
# An input with 4097 tokens that is above the size of the sliding window # An input with 4097 tokens that is above the size of the sliding window

View File

@@ -606,6 +606,7 @@ class Qwen2MoeIntegrationTest(unittest.TestCase):
@require_bitsandbytes @require_bitsandbytes
@slow @slow
@require_flash_attn @require_flash_attn
@pytest.mark.flash_attn_test
def test_model_a2_7b_long_prompt(self): def test_model_a2_7b_long_prompt(self):
EXPECTED_OUTPUT_TOKEN_IDS = [306, 338] EXPECTED_OUTPUT_TOKEN_IDS = [306, 338]
# An input with 4097 tokens that is above the size of the sliding window # An input with 4097 tokens that is above the size of the sliding window

View File

@@ -16,6 +16,7 @@
import unittest import unittest
import pytest
from parameterized import parameterized from parameterized import parameterized
from transformers import StableLmConfig, is_torch_available, set_seed from transformers import StableLmConfig, is_torch_available, set_seed
@@ -539,6 +540,7 @@ class StableLmModelIntegrationTest(unittest.TestCase):
@require_bitsandbytes @require_bitsandbytes
@slow @slow
@require_flash_attn @require_flash_attn
@pytest.mark.flash_attn_test
def test_model_3b_long_prompt(self): def test_model_3b_long_prompt(self):
EXPECTED_OUTPUT_TOKEN_IDS = [3, 3, 3] EXPECTED_OUTPUT_TOKEN_IDS = [3, 3, 3]
input_ids = [306, 338] * 2047 input_ids = [306, 338] * 2047

View File

@@ -528,6 +528,7 @@ class Starcoder2IntegrationTest(unittest.TestCase):
self.assertEqual(EXPECTED_TEXT, output_text) self.assertEqual(EXPECTED_TEXT, output_text)
@require_flash_attn @require_flash_attn
@pytest.mark.flash_attn_test
def test_starcoder2_batched_generation_fa2(self): def test_starcoder2_batched_generation_fa2(self):
EXPECTED_TEXT = [ EXPECTED_TEXT = [
"Hello my name is Younes and I am a student at the University of Liverpool. I am currently studying for my MSc in Computer Science. I am interested in the field of Machine Learning and I am currently working on", "Hello my name is Younes and I am a student at the University of Liverpool. I am currently studying for my MSc in Computer Science. I am interested in the field of Machine Learning and I am currently working on",