add recommendations for NPU using flash_attn (#36383)

* add recommendations for Ascend NPU using flash_attn

* update recommend_message_npu

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

---------

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
This commit is contained in:
zheliuyu
2025-02-26 21:51:08 +08:00
committed by GitHub
parent b4965cecc5
commit 6513e5e402
2 changed files with 13 additions and 1 deletions

View File

@@ -97,6 +97,7 @@ from .utils import (
is_safetensors_available, is_safetensors_available,
is_torch_flex_attn_available, is_torch_flex_attn_available,
is_torch_greater_or_equal, is_torch_greater_or_equal,
is_torch_npu_available,
is_torch_sdpa_available, is_torch_sdpa_available,
is_torch_xla_available, is_torch_xla_available,
logging, logging,
@@ -1746,7 +1747,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
install_message = "Please refer to the documentation of https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2 to install Flash Attention 2." install_message = "Please refer to the documentation of https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2 to install Flash Attention 2."
if importlib.util.find_spec("flash_attn") is None: if importlib.util.find_spec("flash_attn") is None:
raise ImportError(f"{preface} the package flash_attn seems to be not installed. {install_message}") if is_torch_npu_available():
recommend_message_npu = "You should use attn_implementation='sdpa' instead when using NPU. "
raise ImportError(
f"{preface} the package flash_attn is not supported on Ascend NPU. {recommend_message_npu}"
)
else:
raise ImportError(f"{preface} the package flash_attn seems to be not installed. {install_message}")
flash_attention_version = version.parse(importlib.metadata.version("flash_attn")) flash_attention_version = version.parse(importlib.metadata.version("flash_attn"))
if torch.version.cuda: if torch.version.cuda:

View File

@@ -958,6 +958,11 @@ def is_flash_attn_2_available():
if not (torch.cuda.is_available() or is_torch_mlu_available()): if not (torch.cuda.is_available() or is_torch_mlu_available()):
return False return False
# Ascend does not support "flash_attn".
# If "flash_attn" is left in the env, is_flash_attn_2_available() should return False.
if is_torch_npu_available():
return False
if torch.version.cuda: if torch.version.cuda:
return version.parse(importlib.metadata.version("flash_attn")) >= version.parse("2.1.0") return version.parse(importlib.metadata.version("flash_attn")) >= version.parse("2.1.0")
elif torch.version.hip: elif torch.version.hip: