@@ -70,7 +70,7 @@ from .utils import (
|
|||||||
is_accelerate_available,
|
is_accelerate_available,
|
||||||
is_auto_gptq_available,
|
is_auto_gptq_available,
|
||||||
is_bitsandbytes_available,
|
is_bitsandbytes_available,
|
||||||
is_flash_attn_available,
|
is_flash_attn_2_available,
|
||||||
is_offline_mode,
|
is_offline_mode,
|
||||||
is_optimum_available,
|
is_optimum_available,
|
||||||
is_peft_available,
|
is_peft_available,
|
||||||
@@ -1269,7 +1269,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
"request support for this architecture: https://github.com/huggingface/transformers/issues/new"
|
"request support for this architecture: https://github.com/huggingface/transformers/issues/new"
|
||||||
)
|
)
|
||||||
|
|
||||||
if not is_flash_attn_available():
|
if not is_flash_attn_2_available():
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
"Flash Attention 2.0 is not available. Please refer to the documentation of https://github.com/Dao-AILab/flash-attention for"
|
"Flash Attention 2.0 is not available. Please refer to the documentation of https://github.com/Dao-AILab/flash-attention for"
|
||||||
" installing it."
|
" installing it."
|
||||||
|
|||||||
@@ -35,13 +35,13 @@ from ...utils import (
|
|||||||
add_code_sample_docstrings,
|
add_code_sample_docstrings,
|
||||||
add_start_docstrings,
|
add_start_docstrings,
|
||||||
add_start_docstrings_to_model_forward,
|
add_start_docstrings_to_model_forward,
|
||||||
is_flash_attn_available,
|
is_flash_attn_2_available,
|
||||||
logging,
|
logging,
|
||||||
)
|
)
|
||||||
from .configuration_falcon import FalconConfig
|
from .configuration_falcon import FalconConfig
|
||||||
|
|
||||||
|
|
||||||
if is_flash_attn_available():
|
if is_flash_attn_2_available():
|
||||||
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
||||||
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
||||||
|
|
||||||
|
|||||||
@@ -34,14 +34,14 @@ from ...pytorch_utils import ALL_LAYERNORM_LAYERS
|
|||||||
from ...utils import (
|
from ...utils import (
|
||||||
add_start_docstrings,
|
add_start_docstrings,
|
||||||
add_start_docstrings_to_model_forward,
|
add_start_docstrings_to_model_forward,
|
||||||
is_flash_attn_available,
|
is_flash_attn_2_available,
|
||||||
logging,
|
logging,
|
||||||
replace_return_docstrings,
|
replace_return_docstrings,
|
||||||
)
|
)
|
||||||
from .configuration_llama import LlamaConfig
|
from .configuration_llama import LlamaConfig
|
||||||
|
|
||||||
|
|
||||||
if is_flash_attn_available():
|
if is_flash_attn_2_available():
|
||||||
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
||||||
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
||||||
|
|
||||||
|
|||||||
@@ -34,14 +34,14 @@ from ...modeling_utils import PreTrainedModel
|
|||||||
from ...utils import (
|
from ...utils import (
|
||||||
add_start_docstrings,
|
add_start_docstrings,
|
||||||
add_start_docstrings_to_model_forward,
|
add_start_docstrings_to_model_forward,
|
||||||
is_flash_attn_available,
|
is_flash_attn_2_available,
|
||||||
logging,
|
logging,
|
||||||
replace_return_docstrings,
|
replace_return_docstrings,
|
||||||
)
|
)
|
||||||
from .configuration_mistral import MistralConfig
|
from .configuration_mistral import MistralConfig
|
||||||
|
|
||||||
|
|
||||||
if is_flash_attn_available():
|
if is_flash_attn_2_available():
|
||||||
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
||||||
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
||||||
|
|
||||||
|
|||||||
@@ -60,7 +60,7 @@ from .utils import (
|
|||||||
is_detectron2_available,
|
is_detectron2_available,
|
||||||
is_essentia_available,
|
is_essentia_available,
|
||||||
is_faiss_available,
|
is_faiss_available,
|
||||||
is_flash_attn_available,
|
is_flash_attn_2_available,
|
||||||
is_flax_available,
|
is_flax_available,
|
||||||
is_fsdp_available,
|
is_fsdp_available,
|
||||||
is_ftfy_available,
|
is_ftfy_available,
|
||||||
@@ -432,7 +432,7 @@ def require_flash_attn(test_case):
|
|||||||
These tests are skipped when Flash Attention isn't installed.
|
These tests are skipped when Flash Attention isn't installed.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
return unittest.skipUnless(is_flash_attn_available(), "test requires Flash Attention")(test_case)
|
return unittest.skipUnless(is_flash_attn_2_available(), "test requires Flash Attention")(test_case)
|
||||||
|
|
||||||
|
|
||||||
def require_peft(test_case):
|
def require_peft(test_case):
|
||||||
|
|||||||
@@ -115,7 +115,7 @@ from .import_utils import (
|
|||||||
is_detectron2_available,
|
is_detectron2_available,
|
||||||
is_essentia_available,
|
is_essentia_available,
|
||||||
is_faiss_available,
|
is_faiss_available,
|
||||||
is_flash_attn_available,
|
is_flash_attn_2_available,
|
||||||
is_flax_available,
|
is_flax_available,
|
||||||
is_fsdp_available,
|
is_fsdp_available,
|
||||||
is_ftfy_available,
|
is_ftfy_available,
|
||||||
|
|||||||
@@ -71,7 +71,9 @@ TORCH_FX_REQUIRED_VERSION = version.parse("1.10")
|
|||||||
_accelerate_available, _accelerate_version = _is_package_available("accelerate", return_version=True)
|
_accelerate_available, _accelerate_version = _is_package_available("accelerate", return_version=True)
|
||||||
_apex_available = _is_package_available("apex")
|
_apex_available = _is_package_available("apex")
|
||||||
_bitsandbytes_available = _is_package_available("bitsandbytes")
|
_bitsandbytes_available = _is_package_available("bitsandbytes")
|
||||||
_flash_attn_available = _is_package_available("flash_attn")
|
_flash_attn_2_available = _is_package_available("flash_attn") and version.parse(
|
||||||
|
importlib.metadata.version("flash_attn")
|
||||||
|
) >= version.parse("2.0.0")
|
||||||
# `importlib.metadata.version` doesn't work with `bs4` but `beautifulsoup4`. For `importlib.util.find_spec`, reversed.
|
# `importlib.metadata.version` doesn't work with `bs4` but `beautifulsoup4`. For `importlib.util.find_spec`, reversed.
|
||||||
_bs4_available = importlib.util.find_spec("bs4") is not None
|
_bs4_available = importlib.util.find_spec("bs4") is not None
|
||||||
_coloredlogs_available = _is_package_available("coloredlogs")
|
_coloredlogs_available = _is_package_available("coloredlogs")
|
||||||
@@ -579,14 +581,14 @@ def is_bitsandbytes_available():
|
|||||||
return _bitsandbytes_available and torch.cuda.is_available()
|
return _bitsandbytes_available and torch.cuda.is_available()
|
||||||
|
|
||||||
|
|
||||||
def is_flash_attn_available():
|
def is_flash_attn_2_available():
|
||||||
if not is_torch_available():
|
if not is_torch_available():
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Let's add an extra check to see if cuda is available
|
# Let's add an extra check to see if cuda is available
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
return _flash_attn_available and torch.cuda.is_available()
|
return _flash_attn_2_available and torch.cuda.is_available()
|
||||||
|
|
||||||
|
|
||||||
def is_torchdistx_available():
|
def is_torchdistx_available():
|
||||||
|
|||||||
Reference in New Issue
Block a user