From 6df9179c1c5b19674ce0a4b82d311ef833c02447 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Fri, 13 Oct 2023 12:56:50 +0200 Subject: [PATCH] [`core`] Fix fa-2 import (#26785) * fix fa-2 import * nit --- src/transformers/modeling_utils.py | 4 ++-- src/transformers/models/falcon/modeling_falcon.py | 4 ++-- src/transformers/models/llama/modeling_llama.py | 4 ++-- src/transformers/models/mistral/modeling_mistral.py | 4 ++-- src/transformers/testing_utils.py | 4 ++-- src/transformers/utils/__init__.py | 2 +- src/transformers/utils/import_utils.py | 8 +++++--- 7 files changed, 16 insertions(+), 14 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 95694a867d..0936d78516 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -70,7 +70,7 @@ from .utils import ( is_accelerate_available, is_auto_gptq_available, is_bitsandbytes_available, - is_flash_attn_available, + is_flash_attn_2_available, is_offline_mode, is_optimum_available, is_peft_available, @@ -1269,7 +1269,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix "request support for this architecture: https://github.com/huggingface/transformers/issues/new" ) - if not is_flash_attn_available(): + if not is_flash_attn_2_available(): raise ImportError( "Flash Attention 2.0 is not available. Please refer to the documentation of https://github.com/Dao-AILab/flash-attention for" " installing it." diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index 33b9fdde73..35313e8d9e 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -35,13 +35,13 @@ from ...utils import ( add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, - is_flash_attn_available, + is_flash_attn_2_available, logging, ) from .configuration_falcon import FalconConfig -if is_flash_attn_available(): +if is_flash_attn_2_available(): from flash_attn import flash_attn_func, flash_attn_varlen_func from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 4afa3293ed..b697387f5f 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -34,14 +34,14 @@ from ...pytorch_utils import ALL_LAYERNORM_LAYERS from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, - is_flash_attn_available, + is_flash_attn_2_available, logging, replace_return_docstrings, ) from .configuration_llama import LlamaConfig -if is_flash_attn_available(): +if is_flash_attn_2_available(): from flash_attn import flash_attn_func, flash_attn_varlen_func from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 62610ceb41..28d7b914d6 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -34,14 +34,14 @@ from ...modeling_utils import PreTrainedModel from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, - is_flash_attn_available, + is_flash_attn_2_available, logging, replace_return_docstrings, ) from .configuration_mistral import MistralConfig -if is_flash_attn_available(): +if is_flash_attn_2_available(): from flash_attn import flash_attn_func, flash_attn_varlen_func from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index 341e6cd168..50f50c83c4 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -60,7 +60,7 @@ from .utils import ( is_detectron2_available, is_essentia_available, is_faiss_available, - is_flash_attn_available, + is_flash_attn_2_available, is_flax_available, is_fsdp_available, is_ftfy_available, @@ -432,7 +432,7 @@ def require_flash_attn(test_case): These tests are skipped when Flash Attention isn't installed. """ - return unittest.skipUnless(is_flash_attn_available(), "test requires Flash Attention")(test_case) + return unittest.skipUnless(is_flash_attn_2_available(), "test requires Flash Attention")(test_case) def require_peft(test_case): diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py index 08f9fc8264..b66f0db389 100644 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -115,7 +115,7 @@ from .import_utils import ( is_detectron2_available, is_essentia_available, is_faiss_available, - is_flash_attn_available, + is_flash_attn_2_available, is_flax_available, is_fsdp_available, is_ftfy_available, diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index e2920b291a..fa5952c4fb 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -71,7 +71,9 @@ TORCH_FX_REQUIRED_VERSION = version.parse("1.10") _accelerate_available, _accelerate_version = _is_package_available("accelerate", return_version=True) _apex_available = _is_package_available("apex") _bitsandbytes_available = _is_package_available("bitsandbytes") -_flash_attn_available = _is_package_available("flash_attn") +_flash_attn_2_available = _is_package_available("flash_attn") and version.parse( + importlib.metadata.version("flash_attn") +) >= version.parse("2.0.0") # `importlib.metadata.version` doesn't work with `bs4` but `beautifulsoup4`. For `importlib.util.find_spec`, reversed. _bs4_available = importlib.util.find_spec("bs4") is not None _coloredlogs_available = _is_package_available("coloredlogs") @@ -579,14 +581,14 @@ def is_bitsandbytes_available(): return _bitsandbytes_available and torch.cuda.is_available() -def is_flash_attn_available(): +def is_flash_attn_2_available(): if not is_torch_available(): return False # Let's add an extra check to see if cuda is available import torch - return _flash_attn_available and torch.cuda.is_available() + return _flash_attn_2_available and torch.cuda.is_available() def is_torchdistx_available():