@@ -60,7 +60,7 @@ from .utils import (
|
||||
is_detectron2_available,
|
||||
is_essentia_available,
|
||||
is_faiss_available,
|
||||
is_flash_attn_available,
|
||||
is_flash_attn_2_available,
|
||||
is_flax_available,
|
||||
is_fsdp_available,
|
||||
is_ftfy_available,
|
||||
@@ -432,7 +432,7 @@ def require_flash_attn(test_case):
|
||||
These tests are skipped when Flash Attention isn't installed.
|
||||
|
||||
"""
|
||||
return unittest.skipUnless(is_flash_attn_available(), "test requires Flash Attention")(test_case)
|
||||
return unittest.skipUnless(is_flash_attn_2_available(), "test requires Flash Attention")(test_case)
|
||||
|
||||
|
||||
def require_peft(test_case):
|
||||
|
||||
Reference in New Issue
Block a user