[Feature] Support using FlashAttention2 on Ascend NPU (#36696)
* [Feature] Support using flash-attention on Ascend NPU
* Fix qwen3 and qwen3_moe modular conversion mismatch
This commit is contained in:
@@ -48,6 +48,7 @@ from transformers import (
|
||||
is_torch_available,
|
||||
logging,
|
||||
)
|
||||
from transformers.modeling_flash_attention_utils import is_flash_attn_available
|
||||
from transformers.testing_utils import (
|
||||
TOKEN,
|
||||
CaptureLogger,
|
||||
@@ -79,6 +80,7 @@ from transformers.utils.import_utils import (
|
||||
is_flash_attn_2_available,
|
||||
is_flax_available,
|
||||
is_tf_available,
|
||||
is_torch_npu_available,
|
||||
is_torch_sdpa_available,
|
||||
is_torchdynamo_available,
|
||||
)
|
||||
@@ -653,7 +655,7 @@ class ModelUtilsTest(TestCasePlus):
|
||||
if is_torch_sdpa_available():
|
||||
attn_implementation_available.append("sdpa")
|
||||
|
||||
if is_flash_attn_2_available():
|
||||
if is_flash_attn_available():
|
||||
attn_implementation_available.append("flash_attention_2")
|
||||
|
||||
for requested_attn_implementation in attn_implementation_available:
|
||||
@@ -677,7 +679,7 @@ class ModelUtilsTest(TestCasePlus):
|
||||
if is_torch_sdpa_available():
|
||||
attn_implementation_available.append("sdpa")
|
||||
|
||||
if is_flash_attn_2_available():
|
||||
if is_flash_attn_available():
|
||||
attn_implementation_available.append("flash_attention_2")
|
||||
|
||||
for requested_attn_implementation in attn_implementation_available:
|
||||
@@ -2676,6 +2678,11 @@ class TestAttentionImplementation(unittest.TestCase):
|
||||
if is_flash_attn_2_available():
|
||||
self.skipTest(reason="Please uninstall flash-attn package to run test_not_available_flash")
|
||||
|
||||
if is_torch_npu_available():
|
||||
self.skipTest(
|
||||
reason="FlashAttention2 is supported on Ascend NPU without using package `flash-attn`, ignore this test case."
|
||||
)
|
||||
|
||||
with self.assertRaises(ImportError) as cm:
|
||||
_ = AutoModel.from_pretrained(
|
||||
"hf-internal-testing/tiny-random-GPTBigCodeModel", attn_implementation="flash_attention_2"
|
||||
@@ -2686,6 +2693,11 @@ class TestAttentionImplementation(unittest.TestCase):
|
||||
if is_flash_attn_2_available():
|
||||
self.skipTest(reason="Please uninstall flash-attn package to run test_not_available_flash")
|
||||
|
||||
if is_torch_npu_available():
|
||||
self.skipTest(
|
||||
reason="FlashAttention2 is supported on Ascend NPU without using package `flash-attn`, ignore this test case."
|
||||
)
|
||||
|
||||
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-GPTBigCodeModel")
|
||||
|
||||
with self.assertRaises(ImportError) as cm:
|
||||
|
||||
Reference in New Issue
Block a user