[Feature] Support using FlashAttention2 on Ascend NPU (#36696)

* [Feature] Support using flash-attention on Ascend NPU

* Fix qwen3 and qwen3_moe moduler conversion mismatch
This commit is contained in:
Zhen
2025-03-31 22:12:58 +08:00
committed by GitHub
parent a03cee7a1d
commit e686fed635
55 changed files with 447 additions and 234 deletions

View File

@@ -2276,11 +2276,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
install_message = "Please refer to the documentation of https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2 to install Flash Attention 2."
if importlib.util.find_spec("flash_attn") is None:
# package `flash-attn` can not be installed on Ascend NPU, ignore related validation logic and early exit.
if is_torch_npu_available():
recommend_message_npu = "You should use attn_implementation='sdpa' instead when using NPU. "
raise ImportError(
f"{preface} the package flash_attn is not supported on Ascend NPU. {recommend_message_npu}"
)
if not hard_check_only:
config._attn_implementation = "flash_attention_2"
logger.info("Detect using FlashAttention2 on Ascend NPU.")
return config
else:
raise ImportError(f"{preface} the package flash_attn seems to be not installed. {install_message}")