add the missing flash attention test marker (#32419)

* add flash attention check

* fix

* fix

* add the missing marker

* bug fix

* add one more

* remove order

* add one more
This commit is contained in:
Fanli Lin
2024-08-06 18:18:58 +08:00
committed by GitHub
parent 0aa8328293
commit e85d86398a
7 changed files with 9 additions and 2 deletions

View File

@@ -16,6 +16,7 @@
import unittest
import pytest
from parameterized import parameterized
from transformers import StableLmConfig, is_torch_available, set_seed
@@ -539,6 +540,7 @@ class StableLmModelIntegrationTest(unittest.TestCase):
@require_bitsandbytes
@slow
@require_flash_attn
@pytest.mark.flash_attn_test
def test_model_3b_long_prompt(self):
EXPECTED_OUTPUT_TOKEN_IDS = [3, 3, 3]
input_ids = [306, 338] * 2047