[Feature] Support using FlashAttention2 on Ascend NPU (#36696)

* [Feature] Support using flash-attention on Ascend NPU

* Fix qwen3 and qwen3_moe modular conversion mismatch
This commit is contained in:
Zhen
2025-03-31 22:12:58 +08:00
committed by GitHub
parent a03cee7a1d
commit e686fed635
55 changed files with 447 additions and 234 deletions

View File

@@ -48,6 +48,7 @@ from transformers import (
is_torch_available,
logging,
)
from transformers.modeling_flash_attention_utils import is_flash_attn_available
from transformers.testing_utils import (
TOKEN,
CaptureLogger,
@@ -79,6 +80,7 @@ from transformers.utils.import_utils import (
is_flash_attn_2_available,
is_flax_available,
is_tf_available,
is_torch_npu_available,
is_torch_sdpa_available,
is_torchdynamo_available,
)
@@ -653,7 +655,7 @@ class ModelUtilsTest(TestCasePlus):
if is_torch_sdpa_available():
attn_implementation_available.append("sdpa")
if is_flash_attn_2_available():
if is_flash_attn_available():
attn_implementation_available.append("flash_attention_2")
for requested_attn_implementation in attn_implementation_available:
@@ -677,7 +679,7 @@ class ModelUtilsTest(TestCasePlus):
if is_torch_sdpa_available():
attn_implementation_available.append("sdpa")
if is_flash_attn_2_available():
if is_flash_attn_available():
attn_implementation_available.append("flash_attention_2")
for requested_attn_implementation in attn_implementation_available:
@@ -2676,6 +2678,11 @@ class TestAttentionImplementation(unittest.TestCase):
if is_flash_attn_2_available():
self.skipTest(reason="Please uninstall flash-attn package to run test_not_available_flash")
if is_torch_npu_available():
self.skipTest(
reason="FlashAttention2 is supported on Ascend NPU without using package `flash-attn`, ignore this test case."
)
with self.assertRaises(ImportError) as cm:
_ = AutoModel.from_pretrained(
"hf-internal-testing/tiny-random-GPTBigCodeModel", attn_implementation="flash_attention_2"
@@ -2686,6 +2693,11 @@ class TestAttentionImplementation(unittest.TestCase):
if is_flash_attn_2_available():
self.skipTest(reason="Please uninstall flash-attn package to run test_not_available_flash")
if is_torch_npu_available():
self.skipTest(
reason="FlashAttention2 is supported on Ascend NPU without using package `flash-attn`, ignore this test case."
)
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-GPTBigCodeModel")
with self.assertRaises(ImportError) as cm: