@@ -70,7 +70,7 @@ from .utils import (
|
||||
is_accelerate_available,
|
||||
is_auto_gptq_available,
|
||||
is_bitsandbytes_available,
|
||||
is_flash_attn_available,
|
||||
is_flash_attn_2_available,
|
||||
is_offline_mode,
|
||||
is_optimum_available,
|
||||
is_peft_available,
|
||||
@@ -1269,7 +1269,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
"request support for this architecture: https://github.com/huggingface/transformers/issues/new"
|
||||
)
|
||||
|
||||
if not is_flash_attn_available():
|
||||
if not is_flash_attn_2_available():
|
||||
raise ImportError(
|
||||
"Flash Attention 2.0 is not available. Please refer to the documentation of https://github.com/Dao-AILab/flash-attention for"
|
||||
" installing it."
|
||||
|
||||
@@ -35,13 +35,13 @@ from ...utils import (
|
||||
add_code_sample_docstrings,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
is_flash_attn_available,
|
||||
is_flash_attn_2_available,
|
||||
logging,
|
||||
)
|
||||
from .configuration_falcon import FalconConfig
|
||||
|
||||
|
||||
if is_flash_attn_available():
|
||||
if is_flash_attn_2_available():
|
||||
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
||||
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
||||
|
||||
|
||||
@@ -34,14 +34,14 @@ from ...pytorch_utils import ALL_LAYERNORM_LAYERS
|
||||
from ...utils import (
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
is_flash_attn_available,
|
||||
is_flash_attn_2_available,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from .configuration_llama import LlamaConfig
|
||||
|
||||
|
||||
if is_flash_attn_available():
|
||||
if is_flash_attn_2_available():
|
||||
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
||||
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
||||
|
||||
|
||||
@@ -34,14 +34,14 @@ from ...modeling_utils import PreTrainedModel
|
||||
from ...utils import (
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
is_flash_attn_available,
|
||||
is_flash_attn_2_available,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from .configuration_mistral import MistralConfig
|
||||
|
||||
|
||||
if is_flash_attn_available():
|
||||
if is_flash_attn_2_available():
|
||||
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
||||
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
||||
|
||||
|
||||
@@ -60,7 +60,7 @@ from .utils import (
|
||||
is_detectron2_available,
|
||||
is_essentia_available,
|
||||
is_faiss_available,
|
||||
is_flash_attn_available,
|
||||
is_flash_attn_2_available,
|
||||
is_flax_available,
|
||||
is_fsdp_available,
|
||||
is_ftfy_available,
|
||||
@@ -432,7 +432,7 @@ def require_flash_attn(test_case):
|
||||
These tests are skipped when Flash Attention isn't installed.
|
||||
|
||||
"""
|
||||
return unittest.skipUnless(is_flash_attn_available(), "test requires Flash Attention")(test_case)
|
||||
return unittest.skipUnless(is_flash_attn_2_available(), "test requires Flash Attention")(test_case)
|
||||
|
||||
|
||||
def require_peft(test_case):
|
||||
|
||||
@@ -115,7 +115,7 @@ from .import_utils import (
|
||||
is_detectron2_available,
|
||||
is_essentia_available,
|
||||
is_faiss_available,
|
||||
is_flash_attn_available,
|
||||
is_flash_attn_2_available,
|
||||
is_flax_available,
|
||||
is_fsdp_available,
|
||||
is_ftfy_available,
|
||||
|
||||
@@ -71,7 +71,9 @@ TORCH_FX_REQUIRED_VERSION = version.parse("1.10")
|
||||
_accelerate_available, _accelerate_version = _is_package_available("accelerate", return_version=True)
|
||||
_apex_available = _is_package_available("apex")
|
||||
_bitsandbytes_available = _is_package_available("bitsandbytes")
|
||||
_flash_attn_available = _is_package_available("flash_attn")
|
||||
_flash_attn_2_available = _is_package_available("flash_attn") and version.parse(
|
||||
importlib.metadata.version("flash_attn")
|
||||
) >= version.parse("2.0.0")
|
||||
# `importlib.metadata.version` doesn't work with `bs4` but `beautifulsoup4`. For `importlib.util.find_spec`, reversed.
|
||||
_bs4_available = importlib.util.find_spec("bs4") is not None
|
||||
_coloredlogs_available = _is_package_available("coloredlogs")
|
||||
@@ -579,14 +581,14 @@ def is_bitsandbytes_available():
|
||||
return _bitsandbytes_available and torch.cuda.is_available()
|
||||
|
||||
|
||||
def is_flash_attn_available():
|
||||
def is_flash_attn_2_available():
|
||||
if not is_torch_available():
|
||||
return False
|
||||
|
||||
# Let's add an extra check to see if cuda is available
|
||||
import torch
|
||||
|
||||
return _flash_attn_available and torch.cuda.is_available()
|
||||
return _flash_attn_2_available and torch.cuda.is_available()
|
||||
|
||||
|
||||
def is_torchdistx_available():
|
||||
|
||||
Reference in New Issue
Block a user