add the missing flash attention test marker (#32419)

* add flash attention check

* fix

* fix

* add the missing marker

* bug fix

* add one more

* remove order

* add one more
This commit is contained in:
Fanli Lin
2024-08-06 18:18:58 +08:00
committed by GitHub
parent 0aa8328293
commit e85d86398a
7 changed files with 9 additions and 2 deletions

View File

@@ -620,6 +620,7 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
@require_flash_attn
@require_torch_gpu
@slow
@pytest.mark.flash_attn_test
def test_use_flash_attention_2_true(self):
"""
NOTE: this is the only test testing that the legacy `use_flash_attention=2` argument still works as intended.