From 6513e5e402a91d3378c9c05c385a3f4f1f394cbb Mon Sep 17 00:00:00 2001
From: zheliuyu <15750543867@163.com>
Date: Wed, 26 Feb 2025 21:51:08 +0800
Subject: [PATCH] add recommendations for NPU using flash_attn (#36383)

* add recommendations for Ascend NPU using flash_attn

* update recommend_message_npu

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

---------

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
---
 src/transformers/modeling_utils.py     | 9 ++++++++-
 src/transformers/utils/import_utils.py | 5 +++++
 2 files changed, 13 insertions(+), 1 deletion(-)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index d7abc3bf7e..6321757e42 100755
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -97,6 +97,7 @@ from .utils import (
     is_safetensors_available,
     is_torch_flex_attn_available,
     is_torch_greater_or_equal,
+    is_torch_npu_available,
     is_torch_sdpa_available,
     is_torch_xla_available,
     logging,
@@ -1746,7 +1747,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         install_message = "Please refer to the documentation of https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2 to install Flash Attention 2."
 
         if importlib.util.find_spec("flash_attn") is None:
-            raise ImportError(f"{preface} the package flash_attn seems to be not installed. {install_message}")
+            if is_torch_npu_available():
+                recommend_message_npu = "You should use attn_implementation='sdpa' instead when using NPU. "
+                raise ImportError(
+                    f"{preface} the package flash_attn is not supported on Ascend NPU. {recommend_message_npu}"
+                )
+            else:
+                raise ImportError(f"{preface} the package flash_attn seems to be not installed. {install_message}")
 
         flash_attention_version = version.parse(importlib.metadata.version("flash_attn"))
         if torch.version.cuda:
diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py
index c9007d909c..a4a566d1c6 100755
--- a/src/transformers/utils/import_utils.py
+++ b/src/transformers/utils/import_utils.py
@@ -958,6 +958,11 @@ def is_flash_attn_2_available():
     if not (torch.cuda.is_available() or is_torch_mlu_available()):
         return False
 
+    # Ascend does not support "flash_attn".
+    # If "flash_attn" is left in the env, is_flash_attn_2_available() should return False.
+    if is_torch_npu_available():
+        return False
+
     if torch.version.cuda:
         return version.parse(importlib.metadata.version("flash_attn")) >= version.parse("2.1.0")
     elif torch.version.hip: