add the missing flash attention test marker (#32419)
* add flash attention check * fix * fix * add the missing marker * bug fix * add one more * remove order * add one more
This commit is contained in:
@@ -628,9 +628,9 @@ class GemmaIntegrationTest(unittest.TestCase):
|
||||
|
||||
self.assertEqual(output_text, EXPECTED_TEXTS)
|
||||
|
||||
@pytest.mark.flash_attn_test
|
||||
@require_flash_attn
|
||||
@require_read_token
|
||||
@pytest.mark.flash_attn_test
|
||||
def test_model_2b_flash_attn(self):
|
||||
model_id = "google/gemma-2b"
|
||||
EXPECTED_TEXTS = [
|
||||
|
||||
@@ -620,6 +620,7 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
||||
@require_flash_attn
|
||||
@require_torch_gpu
|
||||
@slow
|
||||
@pytest.mark.flash_attn_test
|
||||
def test_use_flash_attention_2_true(self):
|
||||
"""
|
||||
NOTE: this is the only test testing that the legacy `use_flash_attention=2` argument still works as intended.
|
||||
|
||||
@@ -576,9 +576,10 @@ class MistralIntegrationTest(unittest.TestCase):
|
||||
backend_empty_cache(torch_device)
|
||||
gc.collect()
|
||||
|
||||
@require_flash_attn
|
||||
@require_bitsandbytes
|
||||
@slow
|
||||
@require_flash_attn
|
||||
@pytest.mark.flash_attn_test
|
||||
def test_model_7b_long_prompt(self):
|
||||
EXPECTED_OUTPUT_TOKEN_IDS = [306, 338]
|
||||
# An input with 4097 tokens that is above the size of the sliding window
|
||||
|
||||
@@ -544,6 +544,7 @@ class Qwen2IntegrationTest(unittest.TestCase):
|
||||
@require_bitsandbytes
|
||||
@slow
|
||||
@require_flash_attn
|
||||
@pytest.mark.flash_attn_test
|
||||
def test_model_450m_long_prompt(self):
|
||||
EXPECTED_OUTPUT_TOKEN_IDS = [306, 338]
|
||||
# An input with 4097 tokens that is above the size of the sliding window
|
||||
|
||||
@@ -606,6 +606,7 @@ class Qwen2MoeIntegrationTest(unittest.TestCase):
|
||||
@require_bitsandbytes
|
||||
@slow
|
||||
@require_flash_attn
|
||||
@pytest.mark.flash_attn_test
|
||||
def test_model_a2_7b_long_prompt(self):
|
||||
EXPECTED_OUTPUT_TOKEN_IDS = [306, 338]
|
||||
# An input with 4097 tokens that is above the size of the sliding window
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
from parameterized import parameterized
|
||||
|
||||
from transformers import StableLmConfig, is_torch_available, set_seed
|
||||
@@ -539,6 +540,7 @@ class StableLmModelIntegrationTest(unittest.TestCase):
|
||||
@require_bitsandbytes
|
||||
@slow
|
||||
@require_flash_attn
|
||||
@pytest.mark.flash_attn_test
|
||||
def test_model_3b_long_prompt(self):
|
||||
EXPECTED_OUTPUT_TOKEN_IDS = [3, 3, 3]
|
||||
input_ids = [306, 338] * 2047
|
||||
|
||||
@@ -528,6 +528,7 @@ class Starcoder2IntegrationTest(unittest.TestCase):
|
||||
self.assertEqual(EXPECTED_TEXT, output_text)
|
||||
|
||||
@require_flash_attn
|
||||
@pytest.mark.flash_attn_test
|
||||
def test_starcoder2_batched_generation_fa2(self):
|
||||
EXPECTED_TEXT = [
|
||||
"Hello my name is Younes and I am a student at the University of Liverpool. I am currently studying for my MSc in Computer Science. I am interested in the field of Machine Learning and I am currently working on",
|
||||
|
||||
Reference in New Issue
Block a user