Better check for packages availability (#23163)
* Better check for packages availability * amend _optimumneuron_available * amend torch_version * amend PIL detection and lint * lint * amend _faiss_available * remove overloaded signatures of _is_package_available * fix sklearn and decord detection * remove unused checks * revert
This commit is contained in:
committed by
GitHub
parent
d51296d9c2
commit
83eda6435e
@@ -72,6 +72,7 @@ from .utils import (
|
|||||||
get_cached_models,
|
get_cached_models,
|
||||||
get_file_from_repo,
|
get_file_from_repo,
|
||||||
get_full_repo_name,
|
get_full_repo_name,
|
||||||
|
get_torch_version,
|
||||||
has_file,
|
has_file,
|
||||||
http_user_agent,
|
http_user_agent,
|
||||||
is_apex_available,
|
is_apex_available,
|
||||||
@@ -125,5 +126,4 @@ from .utils import (
|
|||||||
to_numpy,
|
to_numpy,
|
||||||
to_py_obj,
|
to_py_obj,
|
||||||
torch_only_method,
|
torch_only_method,
|
||||||
torch_version,
|
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -232,9 +232,9 @@ class OnnxConfig(ABC):
|
|||||||
`bool`: Whether the installed version of PyTorch is compatible with the model.
|
`bool`: Whether the installed version of PyTorch is compatible with the model.
|
||||||
"""
|
"""
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
from transformers.utils import torch_version
|
from transformers.utils import get_torch_version
|
||||||
|
|
||||||
return torch_version >= self.torch_onnx_minimum_version
|
return get_torch_version() >= self.torch_onnx_minimum_version
|
||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|||||||
@@ -334,12 +334,12 @@ def export(
|
|||||||
preprocessor = tokenizer
|
preprocessor = tokenizer
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
from ..utils import torch_version
|
from ..utils import get_torch_version
|
||||||
|
|
||||||
if not config.is_torch_support_available:
|
if not config.is_torch_support_available:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Unsupported PyTorch version for this model. Minimum required is {config.torch_onnx_minimum_version},"
|
f"Unsupported PyTorch version for this model. Minimum required is {config.torch_onnx_minimum_version},"
|
||||||
f" got: {torch_version}"
|
f" got: {get_torch_version()}"
|
||||||
)
|
)
|
||||||
|
|
||||||
if is_torch_available() and issubclass(type(model), PreTrainedModel):
|
if is_torch_available() and issubclass(type(model), PreTrainedModel):
|
||||||
|
|||||||
@@ -99,6 +99,7 @@ from .import_utils import (
|
|||||||
_LazyModule,
|
_LazyModule,
|
||||||
ccl_version,
|
ccl_version,
|
||||||
direct_transformers_import,
|
direct_transformers_import,
|
||||||
|
get_torch_version,
|
||||||
is_accelerate_available,
|
is_accelerate_available,
|
||||||
is_apex_available,
|
is_apex_available,
|
||||||
is_bitsandbytes_available,
|
is_bitsandbytes_available,
|
||||||
@@ -170,7 +171,6 @@ from .import_utils import (
|
|||||||
is_vision_available,
|
is_vision_available,
|
||||||
requires_backends,
|
requires_backends,
|
||||||
torch_only_method,
|
torch_only_method,
|
||||||
torch_version,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -25,7 +25,6 @@ import warnings
|
|||||||
from typing import Any, Callable, Dict, List, Optional, Type, Union
|
from typing import Any, Callable, Dict, List, Optional, Type, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from packaging import version
|
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.fx import Graph, GraphModule, Proxy, Tracer
|
from torch.fx import Graph, GraphModule, Proxy, Tracer
|
||||||
from torch.fx._compatibility import compatibility
|
from torch.fx._compatibility import compatibility
|
||||||
@@ -54,8 +53,13 @@ from ..models.auto.modeling_auto import (
|
|||||||
MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES,
|
MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES,
|
||||||
MODEL_MAPPING_NAMES,
|
MODEL_MAPPING_NAMES,
|
||||||
)
|
)
|
||||||
from ..utils import ENV_VARS_TRUE_VALUES, TORCH_FX_REQUIRED_VERSION, is_peft_available, is_torch_fx_available
|
from ..utils import (
|
||||||
from ..utils.versions import importlib_metadata
|
ENV_VARS_TRUE_VALUES,
|
||||||
|
TORCH_FX_REQUIRED_VERSION,
|
||||||
|
get_torch_version,
|
||||||
|
is_peft_available,
|
||||||
|
is_torch_fx_available,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if is_peft_available():
|
if is_peft_available():
|
||||||
@@ -737,9 +741,8 @@ class HFTracer(Tracer):
|
|||||||
super().__init__(autowrap_modules=autowrap_modules, autowrap_functions=autowrap_functions)
|
super().__init__(autowrap_modules=autowrap_modules, autowrap_functions=autowrap_functions)
|
||||||
|
|
||||||
if not is_torch_fx_available():
|
if not is_torch_fx_available():
|
||||||
torch_version = version.parse(importlib_metadata.version("torch"))
|
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
f"Found an incompatible version of torch. Found version {torch_version}, but only version "
|
f"Found an incompatible version of torch. Found version {get_torch_version()}, but only version "
|
||||||
f"{TORCH_FX_REQUIRED_VERSION} is supported."
|
f"{TORCH_FX_REQUIRED_VERSION} is supported."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ from collections import OrderedDict
|
|||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from itertools import chain
|
from itertools import chain
|
||||||
from types import ModuleType
|
from types import ModuleType
|
||||||
from typing import Any
|
from typing import Any, Tuple, Union
|
||||||
|
|
||||||
from packaging import version
|
from packaging import version
|
||||||
|
|
||||||
@@ -35,6 +35,24 @@ from .versions import importlib_metadata
|
|||||||
|
|
||||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
|
def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[Tuple[bool, str], bool]:
|
||||||
|
# Check we're not importing a "pkg_name" directory somewhere but the actual library by trying to grab the version
|
||||||
|
package_exists = importlib.util.find_spec(pkg_name) is not None
|
||||||
|
package_version = "N/A"
|
||||||
|
if package_exists:
|
||||||
|
try:
|
||||||
|
package_version = importlib_metadata.version(pkg_name)
|
||||||
|
package_exists = True
|
||||||
|
except importlib_metadata.PackageNotFoundError:
|
||||||
|
package_exists = False
|
||||||
|
logger.debug(f"Detected {pkg_name} version {package_version}")
|
||||||
|
if return_version:
|
||||||
|
return package_exists, package_version
|
||||||
|
else:
|
||||||
|
return package_exists
|
||||||
|
|
||||||
|
|
||||||
ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}
|
ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}
|
||||||
ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"})
|
ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"})
|
||||||
|
|
||||||
@@ -44,26 +62,80 @@ USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper()
|
|||||||
|
|
||||||
FORCE_TF_AVAILABLE = os.environ.get("FORCE_TF_AVAILABLE", "AUTO").upper()
|
FORCE_TF_AVAILABLE = os.environ.get("FORCE_TF_AVAILABLE", "AUTO").upper()
|
||||||
|
|
||||||
_torch_version = "N/A"
|
# This is the version of torch required to run torch.fx features and torch.onnx with dictionary inputs.
|
||||||
if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES:
|
TORCH_FX_REQUIRED_VERSION = version.parse("1.10")
|
||||||
_torch_available = importlib.util.find_spec("torch") is not None
|
|
||||||
if _torch_available:
|
|
||||||
|
_accelerate_available, _accelerate_version = _is_package_available("accelerate", return_version=True)
|
||||||
|
_apex_available = _is_package_available("apex")
|
||||||
|
_bitsandbytes_available = _is_package_available("bitsandbytes")
|
||||||
|
_bs4_available = _is_package_available("bs4")
|
||||||
|
_coloredlogs_available = _is_package_available("coloredlogs")
|
||||||
|
_datasets_available = _is_package_available("datasets")
|
||||||
|
_decord_available = importlib.util.find_spec("decord") is not None
|
||||||
|
_detectron2_available = _is_package_available("detectron2")
|
||||||
|
_faiss_available = _is_package_available("faiss") or _is_package_available("faiss-cpu")
|
||||||
|
_ftfy_available = _is_package_available("ftfy")
|
||||||
|
_ipex_available, _ipex_version = _is_package_available("intel_extension_for_pytorch", return_version=True)
|
||||||
|
_jieba_available = _is_package_available("jieba")
|
||||||
|
_kenlm_available = _is_package_available("kenlm")
|
||||||
|
_keras_nlp_available = _is_package_available("keras_nlp")
|
||||||
|
_librosa_available = _is_package_available("librosa")
|
||||||
|
_natten_available = _is_package_available("natten")
|
||||||
|
_ninja_available = _is_package_available("ninja")
|
||||||
|
_onnx_available = _is_package_available("onnx")
|
||||||
|
_openai_available = _is_package_available("openai")
|
||||||
|
_optimum_available = _is_package_available("optimum")
|
||||||
|
_optimumneuron_available = _optimum_available and _is_package_available("optimum.neuron")
|
||||||
|
_pandas_available = _is_package_available("pandas")
|
||||||
|
_peft_available = _is_package_available("peft")
|
||||||
|
_phonemizer_available = _is_package_available("phonemizer")
|
||||||
|
_psutil_available = _is_package_available("psutil")
|
||||||
|
_py3nvml_available = _is_package_available("py3nvml")
|
||||||
|
_pyctcdecode_available = _is_package_available("pyctcdecode")
|
||||||
|
_pytesseract_available = _is_package_available("pytesseract")
|
||||||
|
_pytorch_quantization_available = _is_package_available("pytorch_quantization")
|
||||||
|
_rjieba_available = _is_package_available("rjieba")
|
||||||
|
_sacremoses_available = _is_package_available("sacremoses")
|
||||||
|
_safetensors_available = _is_package_available("safetensors")
|
||||||
|
_scipy_available = _is_package_available("scipy")
|
||||||
|
_sentencepiece_available = _is_package_available("sentencepiece")
|
||||||
|
_sklearn_available = importlib.util.find_spec("sklearn") is not None
|
||||||
|
if _sklearn_available:
|
||||||
try:
|
try:
|
||||||
_torch_version = importlib_metadata.version("torch")
|
importlib_metadata.version("scikit-learn")
|
||||||
logger.info(f"PyTorch version {_torch_version} available.")
|
|
||||||
except importlib_metadata.PackageNotFoundError:
|
except importlib_metadata.PackageNotFoundError:
|
||||||
_torch_available = False
|
_sklearn_available = False
|
||||||
|
_smdistributed_available = _is_package_available("smdistributed")
|
||||||
|
_soundfile_available = _is_package_available("soundfile")
|
||||||
|
_spacy_available = _is_package_available("spacy")
|
||||||
|
_sudachipy_available = _is_package_available("sudachipy")
|
||||||
|
_tensorflow_probability_available = _is_package_available("tensorflow_probability")
|
||||||
|
_tensorflow_text_available = _is_package_available("tensorflow_text")
|
||||||
|
_tf2onnx_available = _is_package_available("tf2onnx")
|
||||||
|
_timm_available = _is_package_available("timm")
|
||||||
|
_tokenizers_available = _is_package_available("tokenizers")
|
||||||
|
_torchaudio_available = _is_package_available("torchaudio")
|
||||||
|
_torchdistx_available = _is_package_available("torchdistx")
|
||||||
|
_torchvision_available = _is_package_available("torchvision")
|
||||||
|
|
||||||
|
|
||||||
|
_torch_version = "N/A"
|
||||||
|
_torch_available = False
|
||||||
|
if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES:
|
||||||
|
_torch_available, _torch_version = _is_package_available("torch", return_version=True)
|
||||||
else:
|
else:
|
||||||
logger.info("Disabling PyTorch because USE_TF is set")
|
logger.info("Disabling PyTorch because USE_TF is set")
|
||||||
_torch_available = False
|
_torch_available = False
|
||||||
|
|
||||||
|
|
||||||
_tf_version = "N/A"
|
_tf_version = "N/A"
|
||||||
|
_tf_available = False
|
||||||
if FORCE_TF_AVAILABLE in ENV_VARS_TRUE_VALUES:
|
if FORCE_TF_AVAILABLE in ENV_VARS_TRUE_VALUES:
|
||||||
_tf_available = True
|
_tf_available = True
|
||||||
else:
|
else:
|
||||||
if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES:
|
if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES:
|
||||||
_tf_available = importlib.util.find_spec("tensorflow") is not None
|
_tf_available = _is_package_available("tensorflow")
|
||||||
if _tf_available:
|
if _tf_available:
|
||||||
candidates = (
|
candidates = (
|
||||||
"tensorflow",
|
"tensorflow",
|
||||||
@@ -93,180 +165,10 @@ else:
|
|||||||
f"TensorFlow found but with version {_tf_version}. Transformers requires version 2 minimum."
|
f"TensorFlow found but with version {_tf_version}. Transformers requires version 2 minimum."
|
||||||
)
|
)
|
||||||
_tf_available = False
|
_tf_available = False
|
||||||
else:
|
|
||||||
logger.info(f"TensorFlow version {_tf_version} available.")
|
|
||||||
else:
|
else:
|
||||||
logger.info("Disabling Tensorflow because USE_TORCH is set")
|
logger.info("Disabling Tensorflow because USE_TORCH is set")
|
||||||
_tf_available = False
|
|
||||||
|
|
||||||
|
|
||||||
if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES:
|
|
||||||
_flax_available = importlib.util.find_spec("jax") is not None and importlib.util.find_spec("flax") is not None
|
|
||||||
if _flax_available:
|
|
||||||
try:
|
|
||||||
_jax_version = importlib_metadata.version("jax")
|
|
||||||
_flax_version = importlib_metadata.version("flax")
|
|
||||||
logger.info(f"JAX version {_jax_version}, Flax version {_flax_version} available.")
|
|
||||||
except importlib_metadata.PackageNotFoundError:
|
|
||||||
_flax_available = False
|
|
||||||
else:
|
|
||||||
_flax_available = False
|
|
||||||
|
|
||||||
|
|
||||||
_datasets_available = importlib.util.find_spec("datasets") is not None
|
|
||||||
try:
|
|
||||||
# Check we're not importing a "datasets" directory somewhere but the actual library by trying to grab the version
|
|
||||||
# AND checking it has an author field in the metadata that is HuggingFace.
|
|
||||||
_ = importlib_metadata.version("datasets")
|
|
||||||
_datasets_metadata = importlib_metadata.metadata("datasets")
|
|
||||||
if _datasets_metadata.get("author", "") != "HuggingFace Inc.":
|
|
||||||
_datasets_available = False
|
|
||||||
except importlib_metadata.PackageNotFoundError:
|
|
||||||
_datasets_available = False
|
|
||||||
|
|
||||||
|
|
||||||
_diffusers_available = importlib.util.find_spec("diffusers") is not None
|
|
||||||
try:
|
|
||||||
_diffusers_version = importlib_metadata.version("diffusers")
|
|
||||||
logger.debug(f"Successfully imported diffusers version {_diffusers_version}")
|
|
||||||
except importlib_metadata.PackageNotFoundError:
|
|
||||||
_diffusers_available = False
|
|
||||||
|
|
||||||
|
|
||||||
_detectron2_available = importlib.util.find_spec("detectron2") is not None
|
|
||||||
try:
|
|
||||||
_detectron2_version = importlib_metadata.version("detectron2")
|
|
||||||
logger.debug(f"Successfully imported detectron2 version {_detectron2_version}")
|
|
||||||
except importlib_metadata.PackageNotFoundError:
|
|
||||||
_detectron2_available = False
|
|
||||||
|
|
||||||
|
|
||||||
_faiss_available = importlib.util.find_spec("faiss") is not None
|
|
||||||
try:
|
|
||||||
_faiss_version = importlib_metadata.version("faiss")
|
|
||||||
logger.debug(f"Successfully imported faiss version {_faiss_version}")
|
|
||||||
except importlib_metadata.PackageNotFoundError:
|
|
||||||
try:
|
|
||||||
_faiss_version = importlib_metadata.version("faiss-cpu")
|
|
||||||
logger.debug(f"Successfully imported faiss version {_faiss_version}")
|
|
||||||
except importlib_metadata.PackageNotFoundError:
|
|
||||||
_faiss_available = False
|
|
||||||
|
|
||||||
|
|
||||||
_ftfy_available = importlib.util.find_spec("ftfy") is not None
|
|
||||||
try:
|
|
||||||
_ftfy_version = importlib_metadata.version("ftfy")
|
|
||||||
logger.debug(f"Successfully imported ftfy version {_ftfy_version}")
|
|
||||||
except importlib_metadata.PackageNotFoundError:
|
|
||||||
_ftfy_available = False
|
|
||||||
|
|
||||||
|
|
||||||
coloredlogs = importlib.util.find_spec("coloredlogs") is not None
|
|
||||||
try:
|
|
||||||
_coloredlogs_available = importlib_metadata.version("coloredlogs")
|
|
||||||
logger.debug(f"Successfully imported sympy version {_coloredlogs_available}")
|
|
||||||
except importlib_metadata.PackageNotFoundError:
|
|
||||||
_coloredlogs_available = False
|
|
||||||
|
|
||||||
|
|
||||||
sympy_available = importlib.util.find_spec("sympy") is not None
|
|
||||||
try:
|
|
||||||
_sympy_available = importlib_metadata.version("sympy")
|
|
||||||
logger.debug(f"Successfully imported sympy version {_sympy_available}")
|
|
||||||
except importlib_metadata.PackageNotFoundError:
|
|
||||||
_sympy_available = False
|
|
||||||
|
|
||||||
|
|
||||||
_tf2onnx_available = importlib.util.find_spec("tf2onnx") is not None
|
|
||||||
try:
|
|
||||||
_tf2onnx_version = importlib_metadata.version("tf2onnx")
|
|
||||||
logger.debug(f"Successfully imported tf2onnx version {_tf2onnx_version}")
|
|
||||||
except importlib_metadata.PackageNotFoundError:
|
|
||||||
_tf2onnx_available = False
|
|
||||||
|
|
||||||
|
|
||||||
_onnx_available = importlib.util.find_spec("onnxruntime") is not None
|
|
||||||
try:
|
|
||||||
_onxx_version = importlib_metadata.version("onnx")
|
|
||||||
logger.debug(f"Successfully imported onnx version {_onxx_version}")
|
|
||||||
except importlib_metadata.PackageNotFoundError:
|
|
||||||
_onnx_available = False
|
|
||||||
|
|
||||||
|
|
||||||
_opencv_available = importlib.util.find_spec("cv2") is not None
|
|
||||||
|
|
||||||
|
|
||||||
_pytorch_quantization_available = importlib.util.find_spec("pytorch_quantization") is not None
|
|
||||||
try:
|
|
||||||
_pytorch_quantization_version = importlib_metadata.version("pytorch_quantization")
|
|
||||||
logger.debug(f"Successfully imported pytorch-quantization version {_pytorch_quantization_version}")
|
|
||||||
except importlib_metadata.PackageNotFoundError:
|
|
||||||
_pytorch_quantization_available = False
|
|
||||||
|
|
||||||
|
|
||||||
_soundfile_available = importlib.util.find_spec("soundfile") is not None
|
|
||||||
try:
|
|
||||||
_soundfile_version = importlib_metadata.version("soundfile")
|
|
||||||
logger.debug(f"Successfully imported soundfile version {_soundfile_version}")
|
|
||||||
except importlib_metadata.PackageNotFoundError:
|
|
||||||
_soundfile_available = False
|
|
||||||
|
|
||||||
|
|
||||||
_tensorflow_probability_available = importlib.util.find_spec("tensorflow_probability") is not None
|
|
||||||
try:
|
|
||||||
_tensorflow_probability_version = importlib_metadata.version("tensorflow_probability")
|
|
||||||
logger.debug(f"Successfully imported tensorflow-probability version {_tensorflow_probability_version}")
|
|
||||||
except importlib_metadata.PackageNotFoundError:
|
|
||||||
_tensorflow_probability_available = False
|
|
||||||
|
|
||||||
|
|
||||||
_timm_available = importlib.util.find_spec("timm") is not None
|
|
||||||
try:
|
|
||||||
_timm_version = importlib_metadata.version("timm")
|
|
||||||
logger.debug(f"Successfully imported timm version {_timm_version}")
|
|
||||||
except importlib_metadata.PackageNotFoundError:
|
|
||||||
_timm_available = False
|
|
||||||
|
|
||||||
|
|
||||||
_natten_available = importlib.util.find_spec("natten") is not None
|
|
||||||
try:
|
|
||||||
_natten_version = importlib_metadata.version("natten")
|
|
||||||
logger.debug(f"Successfully imported natten version {_natten_version}")
|
|
||||||
except importlib_metadata.PackageNotFoundError:
|
|
||||||
_natten_available = False
|
|
||||||
|
|
||||||
|
|
||||||
_torchaudio_available = importlib.util.find_spec("torchaudio") is not None
|
|
||||||
try:
|
|
||||||
_torchaudio_version = importlib_metadata.version("torchaudio")
|
|
||||||
logger.debug(f"Successfully imported torchaudio version {_torchaudio_version}")
|
|
||||||
except importlib_metadata.PackageNotFoundError:
|
|
||||||
_torchaudio_available = False
|
|
||||||
|
|
||||||
|
|
||||||
_phonemizer_available = importlib.util.find_spec("phonemizer") is not None
|
|
||||||
try:
|
|
||||||
_phonemizer_version = importlib_metadata.version("phonemizer")
|
|
||||||
logger.debug(f"Successfully imported phonemizer version {_phonemizer_version}")
|
|
||||||
except importlib_metadata.PackageNotFoundError:
|
|
||||||
_phonemizer_available = False
|
|
||||||
|
|
||||||
|
|
||||||
_pyctcdecode_available = importlib.util.find_spec("pyctcdecode") is not None
|
|
||||||
try:
|
|
||||||
_pyctcdecode_version = importlib_metadata.version("pyctcdecode")
|
|
||||||
logger.debug(f"Successfully imported pyctcdecode version {_pyctcdecode_version}")
|
|
||||||
except importlib_metadata.PackageNotFoundError:
|
|
||||||
_pyctcdecode_available = False
|
|
||||||
|
|
||||||
|
|
||||||
_librosa_available = importlib.util.find_spec("librosa") is not None
|
|
||||||
try:
|
|
||||||
_librosa_version = importlib_metadata.version("librosa")
|
|
||||||
logger.debug(f"Successfully imported librosa version {_librosa_version}")
|
|
||||||
except importlib_metadata.PackageNotFoundError:
|
|
||||||
_librosa_available = False
|
|
||||||
|
|
||||||
ccl_version = "N/A"
|
ccl_version = "N/A"
|
||||||
_is_ccl_available = (
|
_is_ccl_available = (
|
||||||
importlib.util.find_spec("torch_ccl") is not None
|
importlib.util.find_spec("torch_ccl") is not None
|
||||||
@@ -274,38 +176,46 @@ _is_ccl_available = (
|
|||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
ccl_version = importlib_metadata.version("oneccl_bind_pt")
|
ccl_version = importlib_metadata.version("oneccl_bind_pt")
|
||||||
logger.debug(f"Successfully imported oneccl_bind_pt version {ccl_version}")
|
logger.debug(f"Detected oneccl_bind_pt version {ccl_version}")
|
||||||
except importlib_metadata.PackageNotFoundError:
|
except importlib_metadata.PackageNotFoundError:
|
||||||
_is_ccl_available = False
|
_is_ccl_available = False
|
||||||
|
|
||||||
_decord_availale = importlib.util.find_spec("decord") is not None
|
|
||||||
try:
|
|
||||||
_decord_version = importlib_metadata.version("decord")
|
|
||||||
logger.debug(f"Successfully imported decord version {_decord_version}")
|
|
||||||
except importlib_metadata.PackageNotFoundError:
|
|
||||||
_decord_availale = False
|
|
||||||
|
|
||||||
_jieba_available = importlib.util.find_spec("jieba") is not None
|
_flax_available = False
|
||||||
try:
|
if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES:
|
||||||
_jieba_version = importlib_metadata.version("jieba")
|
_flax_available, _flax_version = _is_package_available("flax", return_version=True)
|
||||||
logger.debug(f"Successfully imported jieba version {_jieba_version}")
|
if _flax_available:
|
||||||
except importlib_metadata.PackageNotFoundError:
|
_jax_available, _jax_version = _is_package_available("jax", return_version=True)
|
||||||
_jieba_available = False
|
if _jax_available:
|
||||||
|
logger.info(f"JAX version {_jax_version}, Flax version {_flax_version} available.")
|
||||||
|
else:
|
||||||
|
_flax_available = _jax_available = False
|
||||||
|
_jax_version = _flax_version = "N/A"
|
||||||
|
|
||||||
# This is the version of torch required to run torch.fx features and torch.onnx with dictionary inputs.
|
|
||||||
TORCH_FX_REQUIRED_VERSION = version.parse("1.10")
|
_torch_fx_available = False
|
||||||
|
if _torch_available:
|
||||||
|
torch_version = version.parse(_torch_version)
|
||||||
|
_torch_fx_available = (torch_version.major, torch_version.minor) >= (
|
||||||
|
TORCH_FX_REQUIRED_VERSION.major,
|
||||||
|
TORCH_FX_REQUIRED_VERSION.minor,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def is_kenlm_available():
|
def is_kenlm_available():
|
||||||
return importlib.util.find_spec("kenlm") is not None
|
return _kenlm_available
|
||||||
|
|
||||||
|
|
||||||
def is_torch_available():
|
def is_torch_available():
|
||||||
return _torch_available
|
return _torch_available
|
||||||
|
|
||||||
|
|
||||||
|
def get_torch_version():
|
||||||
|
return _torch_version
|
||||||
|
|
||||||
|
|
||||||
def is_torchvision_available():
|
def is_torchvision_available():
|
||||||
return importlib.util.find_spec("torchvision") is not None
|
return _torchvision_available
|
||||||
|
|
||||||
|
|
||||||
def is_pyctcdecode_available():
|
def is_pyctcdecode_available():
|
||||||
@@ -404,26 +314,16 @@ def is_torch_tf32_available():
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
torch_version = None
|
|
||||||
_torch_fx_available = False
|
|
||||||
if _torch_available:
|
|
||||||
torch_version = version.parse(importlib_metadata.version("torch"))
|
|
||||||
_torch_fx_available = (torch_version.major, torch_version.minor) >= (
|
|
||||||
TORCH_FX_REQUIRED_VERSION.major,
|
|
||||||
TORCH_FX_REQUIRED_VERSION.minor,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def is_torch_fx_available():
|
def is_torch_fx_available():
|
||||||
return _torch_fx_available
|
return _torch_fx_available
|
||||||
|
|
||||||
|
|
||||||
def is_peft_available():
|
def is_peft_available():
|
||||||
return importlib.util.find_spec("peft") is not None
|
return _peft_available
|
||||||
|
|
||||||
|
|
||||||
def is_bs4_available():
|
def is_bs4_available():
|
||||||
return importlib.util.find_spec("bs4") is not None
|
return _bs4_available
|
||||||
|
|
||||||
|
|
||||||
def is_tf_available():
|
def is_tf_available():
|
||||||
@@ -443,7 +343,7 @@ def is_onnx_available():
|
|||||||
|
|
||||||
|
|
||||||
def is_openai_available():
|
def is_openai_available():
|
||||||
return importlib.util.find_spec("openai") is not None
|
return _openai_available
|
||||||
|
|
||||||
|
|
||||||
def is_flax_available():
|
def is_flax_available():
|
||||||
@@ -517,40 +417,36 @@ def is_detectron2_available():
|
|||||||
|
|
||||||
|
|
||||||
def is_rjieba_available():
|
def is_rjieba_available():
|
||||||
return importlib.util.find_spec("rjieba") is not None
|
return _rjieba_available
|
||||||
|
|
||||||
|
|
||||||
def is_psutil_available():
|
def is_psutil_available():
|
||||||
return importlib.util.find_spec("psutil") is not None
|
return _psutil_available
|
||||||
|
|
||||||
|
|
||||||
def is_py3nvml_available():
|
def is_py3nvml_available():
|
||||||
return importlib.util.find_spec("py3nvml") is not None
|
return _py3nvml_available
|
||||||
|
|
||||||
|
|
||||||
def is_sacremoses_available():
|
def is_sacremoses_available():
|
||||||
return importlib.util.find_spec("sacremoses") is not None
|
return _sacremoses_available
|
||||||
|
|
||||||
|
|
||||||
def is_apex_available():
|
def is_apex_available():
|
||||||
return importlib.util.find_spec("apex") is not None
|
return _apex_available
|
||||||
|
|
||||||
|
|
||||||
def is_ninja_available():
|
def is_ninja_available():
|
||||||
return importlib.util.find_spec("ninja") is not None
|
return _ninja_available
|
||||||
|
|
||||||
|
|
||||||
def is_ipex_available():
|
def is_ipex_available():
|
||||||
def get_major_and_minor_from_version(full_version):
|
def get_major_and_minor_from_version(full_version):
|
||||||
return str(version.parse(full_version).major) + "." + str(version.parse(full_version).minor)
|
return str(version.parse(full_version).major) + "." + str(version.parse(full_version).minor)
|
||||||
|
|
||||||
if not is_torch_available() or importlib.util.find_spec("intel_extension_for_pytorch") is None:
|
if not is_torch_available() or not _ipex_available:
|
||||||
return False
|
|
||||||
_ipex_version = "N/A"
|
|
||||||
try:
|
|
||||||
_ipex_version = importlib_metadata.version("intel_extension_for_pytorch")
|
|
||||||
except importlib_metadata.PackageNotFoundError:
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
torch_major_and_minor = get_major_and_minor_from_version(_torch_version)
|
torch_major_and_minor = get_major_and_minor_from_version(_torch_version)
|
||||||
ipex_major_and_minor = get_major_and_minor_from_version(_ipex_version)
|
ipex_major_and_minor = get_major_and_minor_from_version(_ipex_version)
|
||||||
if torch_major_and_minor != ipex_major_and_minor:
|
if torch_major_and_minor != ipex_major_and_minor:
|
||||||
@@ -563,11 +459,11 @@ def is_ipex_available():
|
|||||||
|
|
||||||
|
|
||||||
def is_bitsandbytes_available():
|
def is_bitsandbytes_available():
|
||||||
return importlib.util.find_spec("bitsandbytes") is not None
|
return _bitsandbytes_available
|
||||||
|
|
||||||
|
|
||||||
def is_torchdistx_available():
|
def is_torchdistx_available():
|
||||||
return importlib.util.find_spec("torchdistx") is not None
|
return _torchdistx_available
|
||||||
|
|
||||||
|
|
||||||
def is_faiss_available():
|
def is_faiss_available():
|
||||||
@@ -575,17 +471,15 @@ def is_faiss_available():
|
|||||||
|
|
||||||
|
|
||||||
def is_scipy_available():
|
def is_scipy_available():
|
||||||
return importlib.util.find_spec("scipy") is not None
|
return _scipy_available
|
||||||
|
|
||||||
|
|
||||||
def is_sklearn_available():
|
def is_sklearn_available():
|
||||||
if importlib.util.find_spec("sklearn") is None:
|
return _sklearn_available
|
||||||
return False
|
|
||||||
return is_scipy_available() and importlib.util.find_spec("sklearn.metrics")
|
|
||||||
|
|
||||||
|
|
||||||
def is_sentencepiece_available():
|
def is_sentencepiece_available():
|
||||||
return importlib.util.find_spec("sentencepiece") is not None
|
return _sentencepiece_available
|
||||||
|
|
||||||
|
|
||||||
def is_protobuf_available():
|
def is_protobuf_available():
|
||||||
@@ -595,56 +489,54 @@ def is_protobuf_available():
|
|||||||
|
|
||||||
|
|
||||||
def is_accelerate_available(check_partial_state=False):
|
def is_accelerate_available(check_partial_state=False):
|
||||||
accelerate_available = importlib.util.find_spec("accelerate") is not None
|
|
||||||
if accelerate_available:
|
|
||||||
if check_partial_state:
|
if check_partial_state:
|
||||||
return version.parse(importlib_metadata.version("accelerate")) >= version.parse("0.17.0")
|
return _accelerate_available and version.parse(_accelerate_version) >= version.parse("0.17.0")
|
||||||
else:
|
return _accelerate_available
|
||||||
return True
|
|
||||||
else:
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def is_optimum_available():
|
def is_optimum_available():
|
||||||
return importlib.util.find_spec("optimum") is not None
|
return _optimum_available
|
||||||
|
|
||||||
|
|
||||||
def is_optimum_neuron_available():
|
def is_optimum_neuron_available():
|
||||||
return importlib.util.find_spec("optimum.neuron") is not None
|
return _optimumneuron_available
|
||||||
|
|
||||||
|
|
||||||
def is_safetensors_available():
|
def is_safetensors_available():
|
||||||
if is_torch_available():
|
if is_torch_available() and version.parse(_torch_version) < version.parse("1.10"):
|
||||||
if version.parse(_torch_version) >= version.parse("1.10"):
|
|
||||||
return importlib.util.find_spec("safetensors") is not None
|
|
||||||
else:
|
|
||||||
return False
|
return False
|
||||||
else:
|
return _safetensors_available
|
||||||
return importlib.util.find_spec("safetensors") is not None
|
|
||||||
|
|
||||||
|
|
||||||
def is_tokenizers_available():
|
def is_tokenizers_available():
|
||||||
return importlib.util.find_spec("tokenizers") is not None
|
return _tokenizers_available
|
||||||
|
|
||||||
|
|
||||||
def is_vision_available():
|
def is_vision_available():
|
||||||
return importlib.util.find_spec("PIL") is not None
|
_pil_available = importlib.util.find_spec("PIL") is not None
|
||||||
|
if _pil_available:
|
||||||
|
try:
|
||||||
|
package_version = importlib_metadata.version("Pillow")
|
||||||
|
except importlib_metadata.PackageNotFoundError:
|
||||||
|
return False
|
||||||
|
logger.debug(f"Detected PIL version {package_version}")
|
||||||
|
return _pil_available
|
||||||
|
|
||||||
|
|
||||||
def is_pytesseract_available():
|
def is_pytesseract_available():
|
||||||
return importlib.util.find_spec("pytesseract") is not None
|
return _pytesseract_available
|
||||||
|
|
||||||
|
|
||||||
def is_spacy_available():
|
def is_spacy_available():
|
||||||
return importlib.util.find_spec("spacy") is not None
|
return _spacy_available
|
||||||
|
|
||||||
|
|
||||||
def is_tensorflow_text_available():
|
def is_tensorflow_text_available():
|
||||||
return is_tf_available() and importlib.util.find_spec("tensorflow_text") is not None
|
return is_tf_available() and _tensorflow_text_available
|
||||||
|
|
||||||
|
|
||||||
def is_keras_nlp_available():
|
def is_keras_nlp_available():
|
||||||
return is_tensorflow_text_available() and importlib.util.find_spec("keras_nlp") is not None
|
return is_tensorflow_text_available() and _keras_nlp_available
|
||||||
|
|
||||||
|
|
||||||
def is_in_notebook():
|
def is_in_notebook():
|
||||||
@@ -674,7 +566,7 @@ def is_tensorflow_probability_available():
|
|||||||
|
|
||||||
|
|
||||||
def is_pandas_available():
|
def is_pandas_available():
|
||||||
return importlib.util.find_spec("pandas") is not None
|
return _pandas_available
|
||||||
|
|
||||||
|
|
||||||
def is_sagemaker_dp_enabled():
|
def is_sagemaker_dp_enabled():
|
||||||
@@ -688,7 +580,7 @@ def is_sagemaker_dp_enabled():
|
|||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
return False
|
return False
|
||||||
# Lastly, check if the `smdistributed` module is present.
|
# Lastly, check if the `smdistributed` module is present.
|
||||||
return importlib.util.find_spec("smdistributed") is not None
|
return _smdistributed_available
|
||||||
|
|
||||||
|
|
||||||
def is_sagemaker_mp_enabled():
|
def is_sagemaker_mp_enabled():
|
||||||
@@ -712,7 +604,7 @@ def is_sagemaker_mp_enabled():
|
|||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
return False
|
return False
|
||||||
# Lastly, check if the `smdistributed` module is present.
|
# Lastly, check if the `smdistributed` module is present.
|
||||||
return importlib.util.find_spec("smdistributed") is not None
|
return _smdistributed_available
|
||||||
|
|
||||||
|
|
||||||
def is_training_run_on_sagemaker():
|
def is_training_run_on_sagemaker():
|
||||||
@@ -762,11 +654,11 @@ def is_ccl_available():
|
|||||||
|
|
||||||
|
|
||||||
def is_decord_available():
|
def is_decord_available():
|
||||||
return _decord_availale
|
return _decord_available
|
||||||
|
|
||||||
|
|
||||||
def is_sudachi_available():
|
def is_sudachi_available():
|
||||||
return importlib.util.find_spec("sudachipy") is not None
|
return _sudachipy_available
|
||||||
|
|
||||||
|
|
||||||
def is_jumanpp_available():
|
def is_jumanpp_available():
|
||||||
|
|||||||
@@ -319,12 +319,12 @@ class OnnxExportTestCaseV2(TestCase):
|
|||||||
onnx_config = onnx_config_class_constructor(model.config)
|
onnx_config = onnx_config_class_constructor(model.config)
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
from transformers.utils import torch_version
|
from transformers.utils import get_torch_version
|
||||||
|
|
||||||
if torch_version < onnx_config.torch_onnx_minimum_version:
|
if get_torch_version() < onnx_config.torch_onnx_minimum_version:
|
||||||
pytest.skip(
|
pytest.skip(
|
||||||
"Skipping due to incompatible PyTorch version. Minimum required is"
|
"Skipping due to incompatible PyTorch version. Minimum required is"
|
||||||
f" {onnx_config.torch_onnx_minimum_version}, got: {torch_version}"
|
f" {onnx_config.torch_onnx_minimum_version}, got: {get_torch_version()}"
|
||||||
)
|
)
|
||||||
|
|
||||||
preprocessor = get_preprocessor(model_name)
|
preprocessor = get_preprocessor(model_name)
|
||||||
@@ -362,12 +362,12 @@ class OnnxExportTestCaseV2(TestCase):
|
|||||||
onnx_config = onnx_config_class_constructor(model.config)
|
onnx_config = onnx_config_class_constructor(model.config)
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
from transformers.utils import torch_version
|
from transformers.utils import get_torch_version
|
||||||
|
|
||||||
if torch_version < onnx_config.torch_onnx_minimum_version:
|
if get_torch_version() < onnx_config.torch_onnx_minimum_version:
|
||||||
pytest.skip(
|
pytest.skip(
|
||||||
"Skipping due to incompatible PyTorch version. Minimum required is"
|
"Skipping due to incompatible PyTorch version. Minimum required is"
|
||||||
f" {onnx_config.torch_onnx_minimum_version}, got: {torch_version}"
|
f" {onnx_config.torch_onnx_minimum_version}, got: {get_torch_version()}"
|
||||||
)
|
)
|
||||||
|
|
||||||
encoder_model = model.get_encoder()
|
encoder_model = model.get_encoder()
|
||||||
|
|||||||
Reference in New Issue
Block a user