add recommendations for NPU using flash_attn (#36383)
* add recommendations for Ascend NPU using flash_attn
* update recommend_message_npu

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

---------

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
This commit is contained in:
@@ -97,6 +97,7 @@ from .utils import (
|
|||||||
is_safetensors_available,
|
is_safetensors_available,
|
||||||
is_torch_flex_attn_available,
|
is_torch_flex_attn_available,
|
||||||
is_torch_greater_or_equal,
|
is_torch_greater_or_equal,
|
||||||
|
is_torch_npu_available,
|
||||||
is_torch_sdpa_available,
|
is_torch_sdpa_available,
|
||||||
is_torch_xla_available,
|
is_torch_xla_available,
|
||||||
logging,
|
logging,
|
||||||
@@ -1746,7 +1747,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
install_message = "Please refer to the documentation of https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2 to install Flash Attention 2."
|
install_message = "Please refer to the documentation of https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2 to install Flash Attention 2."
|
||||||
|
|
||||||
if importlib.util.find_spec("flash_attn") is None:
|
if importlib.util.find_spec("flash_attn") is None:
|
||||||
raise ImportError(f"{preface} the package flash_attn seems to be not installed. {install_message}")
|
if is_torch_npu_available():
|
||||||
|
recommend_message_npu = "You should use attn_implementation='sdpa' instead when using NPU. "
|
||||||
|
raise ImportError(
|
||||||
|
f"{preface} the package flash_attn is not supported on Ascend NPU. {recommend_message_npu}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ImportError(f"{preface} the package flash_attn seems to be not installed. {install_message}")
|
||||||
|
|
||||||
flash_attention_version = version.parse(importlib.metadata.version("flash_attn"))
|
flash_attention_version = version.parse(importlib.metadata.version("flash_attn"))
|
||||||
if torch.version.cuda:
|
if torch.version.cuda:
|
||||||
|
|||||||
@@ -958,6 +958,11 @@ def is_flash_attn_2_available():
|
|||||||
if not (torch.cuda.is_available() or is_torch_mlu_available()):
|
if not (torch.cuda.is_available() or is_torch_mlu_available()):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
# Ascend does not support "flash_attn".
|
||||||
|
# If "flash_attn" is left in the env, is_flash_attn_2_available() should return False.
|
||||||
|
if is_torch_npu_available():
|
||||||
|
return False
|
||||||
|
|
||||||
if torch.version.cuda:
|
if torch.version.cuda:
|
||||||
return version.parse(importlib.metadata.version("flash_attn")) >= version.parse("2.1.0")
|
return version.parse(importlib.metadata.version("flash_attn")) >= version.parse("2.1.0")
|
||||||
elif torch.version.hip:
|
elif torch.version.hip:
|
||||||
|
|||||||
Reference in New Issue
Block a user