[Feature] Support using FlashAttention2 on Ascend NPU (#36696)
* [Feature] Support using flash-attention on Ascend NPU * Fix qwen3 and qwen3_moe moduler conversion mismatch
This commit is contained in:
@@ -2276,11 +2276,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
install_message = "Please refer to the documentation of https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2 to install Flash Attention 2."
|
||||
|
||||
if importlib.util.find_spec("flash_attn") is None:
|
||||
# package `flash-attn` can not be installed on Ascend NPU, ignore related validation logic and early exit.
|
||||
if is_torch_npu_available():
|
||||
recommend_message_npu = "You should use attn_implementation='sdpa' instead when using NPU. "
|
||||
raise ImportError(
|
||||
f"{preface} the package flash_attn is not supported on Ascend NPU. {recommend_message_npu}"
|
||||
)
|
||||
if not hard_check_only:
|
||||
config._attn_implementation = "flash_attention_2"
|
||||
|
||||
logger.info("Detect using FlashAttention2 on Ascend NPU.")
|
||||
return config
|
||||
else:
|
||||
raise ImportError(f"{preface} the package flash_attn seems to be not installed. {install_message}")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user