add recommendations for NPU using flash_attn (#36383)
* add recommendations for Ascend NPU using flash_attn * update recommend_message_npu Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> --------- Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
This commit is contained in:
@@ -97,6 +97,7 @@ from .utils import (
|
||||
is_safetensors_available,
|
||||
is_torch_flex_attn_available,
|
||||
is_torch_greater_or_equal,
|
||||
is_torch_npu_available,
|
||||
is_torch_sdpa_available,
|
||||
is_torch_xla_available,
|
||||
logging,
|
||||
@@ -1746,7 +1747,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
install_message = "Please refer to the documentation of https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2 to install Flash Attention 2."
|
||||
|
||||
if importlib.util.find_spec("flash_attn") is None:
|
||||
raise ImportError(f"{preface} the package flash_attn seems to be not installed. {install_message}")
|
||||
if is_torch_npu_available():
|
||||
recommend_message_npu = "You should use attn_implementation='sdpa' instead when using NPU. "
|
||||
raise ImportError(
|
||||
f"{preface} the package flash_attn is not supported on Ascend NPU. {recommend_message_npu}"
|
||||
)
|
||||
else:
|
||||
raise ImportError(f"{preface} the package flash_attn seems to be not installed. {install_message}")
|
||||
|
||||
flash_attention_version = version.parse(importlib.metadata.version("flash_attn"))
|
||||
if torch.version.cuda:
|
||||
|
||||
Reference in New Issue
Block a user