From 83eda6435e7c842e55b42a529e9bf367bf2a126b Mon Sep 17 00:00:00 2001 From: Alessandro Pietro Bardelli Date: Thu, 11 May 2023 19:52:22 +0200 Subject: [PATCH] Better check for packages availability (#23163) * Better check for packages availability * amend _optimumneuron_available * amend torch_version * amend PIL detection and lint * lint * amend _faiss_available * remove overloaded signatures of _is_package_available * fix sklearn and decord detection * remove unused checks * revert --- src/transformers/file_utils.py | 2 +- src/transformers/onnx/config.py | 4 +- src/transformers/onnx/convert.py | 4 +- src/transformers/utils/__init__.py | 2 +- src/transformers/utils/fx.py | 13 +- src/transformers/utils/import_utils.py | 404 +++++++++---------------- tests/onnx/test_onnx_v2.py | 12 +- 7 files changed, 168 insertions(+), 273 deletions(-) diff --git a/src/transformers/file_utils.py b/src/transformers/file_utils.py index da24760118..38f4db0581 100644 --- a/src/transformers/file_utils.py +++ b/src/transformers/file_utils.py @@ -72,6 +72,7 @@ from .utils import ( get_cached_models, get_file_from_repo, get_full_repo_name, + get_torch_version, has_file, http_user_agent, is_apex_available, @@ -125,5 +126,4 @@ from .utils import ( to_numpy, to_py_obj, torch_only_method, - torch_version, ) diff --git a/src/transformers/onnx/config.py b/src/transformers/onnx/config.py index bbf06b0792..66236e9864 100644 --- a/src/transformers/onnx/config.py +++ b/src/transformers/onnx/config.py @@ -232,9 +232,9 @@ class OnnxConfig(ABC): `bool`: Whether the installed version of PyTorch is compatible with the model. """ if is_torch_available(): - from transformers.utils import torch_version + from transformers.utils import get_torch_version - return torch_version >= self.torch_onnx_minimum_version + return get_torch_version() >= self.torch_onnx_minimum_version else: return False diff --git a/src/transformers/onnx/convert.py b/src/transformers/onnx/convert.py index 9e9cc93c06..be46f7cd31 100644 --- a/src/transformers/onnx/convert.py +++ b/src/transformers/onnx/convert.py @@ -334,12 +334,12 @@ def export( preprocessor = tokenizer if is_torch_available(): - from ..utils import torch_version + from ..utils import get_torch_version if not config.is_torch_support_available: logger.warning( f"Unsupported PyTorch version for this model. Minimum required is {config.torch_onnx_minimum_version}," - f" got: {torch_version}" + f" got: {get_torch_version()}" ) if is_torch_available() and issubclass(type(model), PreTrainedModel): diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py index a44a8360c6..35d3638aec 100644 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -99,6 +99,7 @@ from .import_utils import ( _LazyModule, ccl_version, direct_transformers_import, + get_torch_version, is_accelerate_available, is_apex_available, is_bitsandbytes_available, @@ -170,7 +171,6 @@ from .import_utils import ( is_vision_available, requires_backends, torch_only_method, - torch_version, ) diff --git a/src/transformers/utils/fx.py b/src/transformers/utils/fx.py index e82d44c802..5fb75cc4fa 100755 --- a/src/transformers/utils/fx.py +++ b/src/transformers/utils/fx.py @@ -25,7 +25,6 @@ import warnings from typing import Any, Callable, Dict, List, Optional, Type, Union import torch -from packaging import version from torch import nn from torch.fx import Graph, GraphModule, Proxy, Tracer from torch.fx._compatibility import compatibility @@ -54,8 +53,13 @@ from ..models.auto.modeling_auto import ( MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES, MODEL_MAPPING_NAMES, ) -from ..utils import ENV_VARS_TRUE_VALUES, TORCH_FX_REQUIRED_VERSION, is_peft_available, is_torch_fx_available -from ..utils.versions import importlib_metadata +from ..utils import ( + ENV_VARS_TRUE_VALUES, + TORCH_FX_REQUIRED_VERSION, + get_torch_version, + is_peft_available, + is_torch_fx_available, +) if is_peft_available(): @@ -737,9 +741,8 @@ class HFTracer(Tracer): super().__init__(autowrap_modules=autowrap_modules, autowrap_functions=autowrap_functions) if not is_torch_fx_available(): - torch_version = version.parse(importlib_metadata.version("torch")) raise ImportError( - f"Found an incompatible version of torch. Found version {torch_version}, but only version " + f"Found an incompatible version of torch. Found version {get_torch_version()}, but only version " f"{TORCH_FX_REQUIRED_VERSION} is supported." ) diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index af169332ec..b349f53827 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -25,7 +25,7 @@ from collections import OrderedDict from functools import lru_cache from itertools import chain from types import ModuleType -from typing import Any +from typing import Any, Tuple, Union from packaging import version @@ -35,6 +35,24 @@ from .versions import importlib_metadata logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[Tuple[bool, str], bool]: + # Check we're not importing a "pkg_name" directory somewhere but the actual library by trying to grab the version + package_exists = importlib.util.find_spec(pkg_name) is not None + package_version = "N/A" + if package_exists: + try: + package_version = importlib_metadata.version(pkg_name) + package_exists = True + except importlib_metadata.PackageNotFoundError: + package_exists = False + logger.debug(f"Detected {pkg_name} version {package_version}") + if return_version: + return package_exists, package_version + else: + return package_exists + + ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"} ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"}) @@ -44,26 +62,80 @@ USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper() FORCE_TF_AVAILABLE = os.environ.get("FORCE_TF_AVAILABLE", "AUTO").upper() +# This is the version of torch required to run torch.fx features and torch.onnx with dictionary inputs. +TORCH_FX_REQUIRED_VERSION = version.parse("1.10") + + +_accelerate_available, _accelerate_version = _is_package_available("accelerate", return_version=True) +_apex_available = _is_package_available("apex") +_bitsandbytes_available = _is_package_available("bitsandbytes") +_bs4_available = _is_package_available("bs4") +_coloredlogs_available = _is_package_available("coloredlogs") +_datasets_available = _is_package_available("datasets") +_decord_available = importlib.util.find_spec("decord") is not None +_detectron2_available = _is_package_available("detectron2") +_faiss_available = _is_package_available("faiss") or _is_package_available("faiss-cpu") +_ftfy_available = _is_package_available("ftfy") +_ipex_available, _ipex_version = _is_package_available("intel_extension_for_pytorch", return_version=True) +_jieba_available = _is_package_available("jieba") +_kenlm_available = _is_package_available("kenlm") +_keras_nlp_available = _is_package_available("keras_nlp") +_librosa_available = _is_package_available("librosa") +_natten_available = _is_package_available("natten") +_ninja_available = _is_package_available("ninja") +_onnx_available = _is_package_available("onnx") +_openai_available = _is_package_available("openai") +_optimum_available = _is_package_available("optimum") +_optimumneuron_available = _optimum_available and _is_package_available("optimum.neuron") +_pandas_available = _is_package_available("pandas") +_peft_available = _is_package_available("peft") +_phonemizer_available = _is_package_available("phonemizer") +_psutil_available = _is_package_available("psutil") +_py3nvml_available = _is_package_available("py3nvml") +_pyctcdecode_available = _is_package_available("pyctcdecode") +_pytesseract_available = _is_package_available("pytesseract") +_pytorch_quantization_available = _is_package_available("pytorch_quantization") +_rjieba_available = _is_package_available("rjieba") +_sacremoses_available = _is_package_available("sacremoses") +_safetensors_available = _is_package_available("safetensors") +_scipy_available = _is_package_available("scipy") +_sentencepiece_available = _is_package_available("sentencepiece") +_sklearn_available = importlib.util.find_spec("sklearn") is not None +if _sklearn_available: + try: + importlib_metadata.version("scikit-learn") + except importlib_metadata.PackageNotFoundError: + _sklearn_available = False +_smdistributed_available = _is_package_available("smdistributed") +_soundfile_available = _is_package_available("soundfile") +_spacy_available = _is_package_available("spacy") +_sudachipy_available = _is_package_available("sudachipy") +_tensorflow_probability_available = _is_package_available("tensorflow_probability") +_tensorflow_text_available = _is_package_available("tensorflow_text") +_tf2onnx_available = _is_package_available("tf2onnx") +_timm_available = _is_package_available("timm") +_tokenizers_available = _is_package_available("tokenizers") +_torchaudio_available = _is_package_available("torchaudio") +_torchdistx_available = _is_package_available("torchdistx") +_torchvision_available = _is_package_available("torchvision") + + _torch_version = "N/A" +_torch_available = False if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES: - _torch_available = importlib.util.find_spec("torch") is not None - if _torch_available: - try: - _torch_version = importlib_metadata.version("torch") - logger.info(f"PyTorch version {_torch_version} available.") - except importlib_metadata.PackageNotFoundError: - _torch_available = False + _torch_available, _torch_version = _is_package_available("torch", return_version=True) else: logger.info("Disabling PyTorch because USE_TF is set") _torch_available = False _tf_version = "N/A" +_tf_available = False if FORCE_TF_AVAILABLE in ENV_VARS_TRUE_VALUES: _tf_available = True else: if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES: - _tf_available = importlib.util.find_spec("tensorflow") is not None + _tf_available = _is_package_available("tensorflow") if _tf_available: candidates = ( "tensorflow", @@ -93,180 +165,10 @@ else: f"TensorFlow found but with version {_tf_version}. Transformers requires version 2 minimum." ) _tf_available = False - else: - logger.info(f"TensorFlow version {_tf_version} available.") else: logger.info("Disabling Tensorflow because USE_TORCH is set") - _tf_available = False -if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES: - _flax_available = importlib.util.find_spec("jax") is not None and importlib.util.find_spec("flax") is not None - if _flax_available: - try: - _jax_version = importlib_metadata.version("jax") - _flax_version = importlib_metadata.version("flax") - logger.info(f"JAX version {_jax_version}, Flax version {_flax_version} available.") - except importlib_metadata.PackageNotFoundError: - _flax_available = False -else: - _flax_available = False - - -_datasets_available = importlib.util.find_spec("datasets") is not None -try: - # Check we're not importing a "datasets" directory somewhere but the actual library by trying to grab the version - # AND checking it has an author field in the metadata that is HuggingFace. - _ = importlib_metadata.version("datasets") - _datasets_metadata = importlib_metadata.metadata("datasets") - if _datasets_metadata.get("author", "") != "HuggingFace Inc.": - _datasets_available = False -except importlib_metadata.PackageNotFoundError: - _datasets_available = False - - -_diffusers_available = importlib.util.find_spec("diffusers") is not None -try: - _diffusers_version = importlib_metadata.version("diffusers") - logger.debug(f"Successfully imported diffusers version {_diffusers_version}") -except importlib_metadata.PackageNotFoundError: - _diffusers_available = False - - -_detectron2_available = importlib.util.find_spec("detectron2") is not None -try: - _detectron2_version = importlib_metadata.version("detectron2") - logger.debug(f"Successfully imported detectron2 version {_detectron2_version}") -except importlib_metadata.PackageNotFoundError: - _detectron2_available = False - - -_faiss_available = importlib.util.find_spec("faiss") is not None -try: - _faiss_version = importlib_metadata.version("faiss") - logger.debug(f"Successfully imported faiss version {_faiss_version}") -except importlib_metadata.PackageNotFoundError: - try: - _faiss_version = importlib_metadata.version("faiss-cpu") - logger.debug(f"Successfully imported faiss version {_faiss_version}") - except importlib_metadata.PackageNotFoundError: - _faiss_available = False - - -_ftfy_available = importlib.util.find_spec("ftfy") is not None -try: - _ftfy_version = importlib_metadata.version("ftfy") - logger.debug(f"Successfully imported ftfy version {_ftfy_version}") -except importlib_metadata.PackageNotFoundError: - _ftfy_available = False - - -coloredlogs = importlib.util.find_spec("coloredlogs") is not None -try: - _coloredlogs_available = importlib_metadata.version("coloredlogs") - logger.debug(f"Successfully imported sympy version {_coloredlogs_available}") -except importlib_metadata.PackageNotFoundError: - _coloredlogs_available = False - - -sympy_available = importlib.util.find_spec("sympy") is not None -try: - _sympy_available = importlib_metadata.version("sympy") - logger.debug(f"Successfully imported sympy version {_sympy_available}") -except importlib_metadata.PackageNotFoundError: - _sympy_available = False - - -_tf2onnx_available = importlib.util.find_spec("tf2onnx") is not None -try: - _tf2onnx_version = importlib_metadata.version("tf2onnx") - logger.debug(f"Successfully imported tf2onnx version {_tf2onnx_version}") -except importlib_metadata.PackageNotFoundError: - _tf2onnx_available = False - - -_onnx_available = importlib.util.find_spec("onnxruntime") is not None -try: - _onxx_version = importlib_metadata.version("onnx") - logger.debug(f"Successfully imported onnx version {_onxx_version}") -except importlib_metadata.PackageNotFoundError: - _onnx_available = False - - -_opencv_available = importlib.util.find_spec("cv2") is not None - - -_pytorch_quantization_available = importlib.util.find_spec("pytorch_quantization") is not None -try: - _pytorch_quantization_version = importlib_metadata.version("pytorch_quantization") - logger.debug(f"Successfully imported pytorch-quantization version {_pytorch_quantization_version}") -except importlib_metadata.PackageNotFoundError: - _pytorch_quantization_available = False - - -_soundfile_available = importlib.util.find_spec("soundfile") is not None -try: - _soundfile_version = importlib_metadata.version("soundfile") - logger.debug(f"Successfully imported soundfile version {_soundfile_version}") -except importlib_metadata.PackageNotFoundError: - _soundfile_available = False - - -_tensorflow_probability_available = importlib.util.find_spec("tensorflow_probability") is not None -try: - _tensorflow_probability_version = importlib_metadata.version("tensorflow_probability") - logger.debug(f"Successfully imported tensorflow-probability version {_tensorflow_probability_version}") -except importlib_metadata.PackageNotFoundError: - _tensorflow_probability_available = False - - -_timm_available = importlib.util.find_spec("timm") is not None -try: - _timm_version = importlib_metadata.version("timm") - logger.debug(f"Successfully imported timm version {_timm_version}") -except importlib_metadata.PackageNotFoundError: - _timm_available = False - - -_natten_available = importlib.util.find_spec("natten") is not None -try: - _natten_version = importlib_metadata.version("natten") - logger.debug(f"Successfully imported natten version {_natten_version}") -except importlib_metadata.PackageNotFoundError: - _natten_available = False - - -_torchaudio_available = importlib.util.find_spec("torchaudio") is not None -try: - _torchaudio_version = importlib_metadata.version("torchaudio") - logger.debug(f"Successfully imported torchaudio version {_torchaudio_version}") -except importlib_metadata.PackageNotFoundError: - _torchaudio_available = False - - -_phonemizer_available = importlib.util.find_spec("phonemizer") is not None -try: - _phonemizer_version = importlib_metadata.version("phonemizer") - logger.debug(f"Successfully imported phonemizer version {_phonemizer_version}") -except importlib_metadata.PackageNotFoundError: - _phonemizer_available = False - - -_pyctcdecode_available = importlib.util.find_spec("pyctcdecode") is not None -try: - _pyctcdecode_version = importlib_metadata.version("pyctcdecode") - logger.debug(f"Successfully imported pyctcdecode version {_pyctcdecode_version}") -except importlib_metadata.PackageNotFoundError: - _pyctcdecode_available = False - - -_librosa_available = importlib.util.find_spec("librosa") is not None -try: - _librosa_version = importlib_metadata.version("librosa") - logger.debug(f"Successfully imported librosa version {_librosa_version}") -except importlib_metadata.PackageNotFoundError: - _librosa_available = False - ccl_version = "N/A" _is_ccl_available = ( importlib.util.find_spec("torch_ccl") is not None @@ -274,38 +176,46 @@ _is_ccl_available = ( ) try: ccl_version = importlib_metadata.version("oneccl_bind_pt") - logger.debug(f"Successfully imported oneccl_bind_pt version {ccl_version}") + logger.debug(f"Detected oneccl_bind_pt version {ccl_version}") except importlib_metadata.PackageNotFoundError: _is_ccl_available = False -_decord_availale = importlib.util.find_spec("decord") is not None -try: - _decord_version = importlib_metadata.version("decord") - logger.debug(f"Successfully imported decord version {_decord_version}") -except importlib_metadata.PackageNotFoundError: - _decord_availale = False -_jieba_available = importlib.util.find_spec("jieba") is not None -try: - _jieba_version = importlib_metadata.version("jieba") - logger.debug(f"Successfully imported jieba version {_jieba_version}") -except importlib_metadata.PackageNotFoundError: - _jieba_available = False +_flax_available = False +if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES: + _flax_available, _flax_version = _is_package_available("flax", return_version=True) + if _flax_available: + _jax_available, _jax_version = _is_package_available("jax", return_version=True) + if _jax_available: + logger.info(f"JAX version {_jax_version}, Flax version {_flax_version} available.") + else: + _flax_available = _jax_available = False + _jax_version = _flax_version = "N/A" -# This is the version of torch required to run torch.fx features and torch.onnx with dictionary inputs. -TORCH_FX_REQUIRED_VERSION = version.parse("1.10") + +_torch_fx_available = False +if _torch_available: + torch_version = version.parse(_torch_version) + _torch_fx_available = (torch_version.major, torch_version.minor) >= ( + TORCH_FX_REQUIRED_VERSION.major, + TORCH_FX_REQUIRED_VERSION.minor, + ) def is_kenlm_available(): - return importlib.util.find_spec("kenlm") is not None + return _kenlm_available def is_torch_available(): return _torch_available +def get_torch_version(): + return _torch_version + + def is_torchvision_available(): - return importlib.util.find_spec("torchvision") is not None + return _torchvision_available def is_pyctcdecode_available(): @@ -404,26 +314,16 @@ def is_torch_tf32_available(): return True -torch_version = None -_torch_fx_available = False -if _torch_available: - torch_version = version.parse(importlib_metadata.version("torch")) - _torch_fx_available = (torch_version.major, torch_version.minor) >= ( - TORCH_FX_REQUIRED_VERSION.major, - TORCH_FX_REQUIRED_VERSION.minor, - ) - - def is_torch_fx_available(): return _torch_fx_available def is_peft_available(): - return importlib.util.find_spec("peft") is not None + return _peft_available def is_bs4_available(): - return importlib.util.find_spec("bs4") is not None + return _bs4_available def is_tf_available(): @@ -443,7 +343,7 @@ def is_onnx_available(): def is_openai_available(): - return importlib.util.find_spec("openai") is not None + return _openai_available def is_flax_available(): @@ -517,40 +417,36 @@ def is_detectron2_available(): def is_rjieba_available(): - return importlib.util.find_spec("rjieba") is not None + return _rjieba_available def is_psutil_available(): - return importlib.util.find_spec("psutil") is not None + return _psutil_available def is_py3nvml_available(): - return importlib.util.find_spec("py3nvml") is not None + return _py3nvml_available def is_sacremoses_available(): - return importlib.util.find_spec("sacremoses") is not None + return _sacremoses_available def is_apex_available(): - return importlib.util.find_spec("apex") is not None + return _apex_available def is_ninja_available(): - return importlib.util.find_spec("ninja") is not None + return _ninja_available def is_ipex_available(): def get_major_and_minor_from_version(full_version): return str(version.parse(full_version).major) + "." + str(version.parse(full_version).minor) - if not is_torch_available() or importlib.util.find_spec("intel_extension_for_pytorch") is None: - return False - _ipex_version = "N/A" - try: - _ipex_version = importlib_metadata.version("intel_extension_for_pytorch") - except importlib_metadata.PackageNotFoundError: + if not is_torch_available() or not _ipex_available: return False + torch_major_and_minor = get_major_and_minor_from_version(_torch_version) ipex_major_and_minor = get_major_and_minor_from_version(_ipex_version) if torch_major_and_minor != ipex_major_and_minor: @@ -563,11 +459,11 @@ def is_ipex_available(): def is_bitsandbytes_available(): - return importlib.util.find_spec("bitsandbytes") is not None + return _bitsandbytes_available def is_torchdistx_available(): - return importlib.util.find_spec("torchdistx") is not None + return _torchdistx_available def is_faiss_available(): @@ -575,17 +471,15 @@ def is_faiss_available(): def is_scipy_available(): - return importlib.util.find_spec("scipy") is not None + return _scipy_available def is_sklearn_available(): - if importlib.util.find_spec("sklearn") is None: - return False - return is_scipy_available() and importlib.util.find_spec("sklearn.metrics") + return _sklearn_available def is_sentencepiece_available(): - return importlib.util.find_spec("sentencepiece") is not None + return _sentencepiece_available def is_protobuf_available(): @@ -595,56 +489,54 @@ def is_protobuf_available(): def is_accelerate_available(check_partial_state=False): - accelerate_available = importlib.util.find_spec("accelerate") is not None - if accelerate_available: - if check_partial_state: - return version.parse(importlib_metadata.version("accelerate")) >= version.parse("0.17.0") - else: - return True - else: - return False + if check_partial_state: + return _accelerate_available and version.parse(_accelerate_version) >= version.parse("0.17.0") + return _accelerate_available def is_optimum_available(): - return importlib.util.find_spec("optimum") is not None + return _optimum_available def is_optimum_neuron_available(): - return importlib.util.find_spec("optimum.neuron") is not None + return _optimumneuron_available def is_safetensors_available(): - if is_torch_available(): - if version.parse(_torch_version) >= version.parse("1.10"): - return importlib.util.find_spec("safetensors") is not None - else: - return False - else: - return importlib.util.find_spec("safetensors") is not None + if is_torch_available() and version.parse(_torch_version) < version.parse("1.10"): + return False + return _safetensors_available def is_tokenizers_available(): - return importlib.util.find_spec("tokenizers") is not None + return _tokenizers_available def is_vision_available(): - return importlib.util.find_spec("PIL") is not None + _pil_available = importlib.util.find_spec("PIL") is not None + if _pil_available: + try: + package_version = importlib_metadata.version("Pillow") + except importlib_metadata.PackageNotFoundError: + return False + logger.debug(f"Detected PIL version {package_version}") + return _pil_available def is_pytesseract_available(): - return importlib.util.find_spec("pytesseract") is not None + return _pytesseract_available def is_spacy_available(): - return importlib.util.find_spec("spacy") is not None + return _spacy_available def is_tensorflow_text_available(): - return is_tf_available() and importlib.util.find_spec("tensorflow_text") is not None + return is_tf_available() and _tensorflow_text_available def is_keras_nlp_available(): - return is_tensorflow_text_available() and importlib.util.find_spec("keras_nlp") is not None + return is_tensorflow_text_available() and _keras_nlp_available def is_in_notebook(): @@ -674,7 +566,7 @@ def is_tensorflow_probability_available(): def is_pandas_available(): - return importlib.util.find_spec("pandas") is not None + return _pandas_available def is_sagemaker_dp_enabled(): @@ -688,7 +580,7 @@ def is_sagemaker_dp_enabled(): except json.JSONDecodeError: return False # Lastly, check if the `smdistributed` module is present. - return importlib.util.find_spec("smdistributed") is not None + return _smdistributed_available def is_sagemaker_mp_enabled(): @@ -712,7 +604,7 @@ def is_sagemaker_mp_enabled(): except json.JSONDecodeError: return False # Lastly, check if the `smdistributed` module is present. - return importlib.util.find_spec("smdistributed") is not None + return _smdistributed_available def is_training_run_on_sagemaker(): @@ -762,11 +654,11 @@ def is_ccl_available(): def is_decord_available(): - return _decord_availale + return _decord_available def is_sudachi_available(): - return importlib.util.find_spec("sudachipy") is not None + return _sudachipy_available def is_jumanpp_available(): diff --git a/tests/onnx/test_onnx_v2.py b/tests/onnx/test_onnx_v2.py index 81e28a3796..796fa1b3ea 100644 --- a/tests/onnx/test_onnx_v2.py +++ b/tests/onnx/test_onnx_v2.py @@ -319,12 +319,12 @@ class OnnxExportTestCaseV2(TestCase): onnx_config = onnx_config_class_constructor(model.config) if is_torch_available(): - from transformers.utils import torch_version + from transformers.utils import get_torch_version - if torch_version < onnx_config.torch_onnx_minimum_version: + if get_torch_version() < onnx_config.torch_onnx_minimum_version: pytest.skip( "Skipping due to incompatible PyTorch version. Minimum required is" - f" {onnx_config.torch_onnx_minimum_version}, got: {torch_version}" + f" {onnx_config.torch_onnx_minimum_version}, got: {get_torch_version()}" ) preprocessor = get_preprocessor(model_name) @@ -362,12 +362,12 @@ class OnnxExportTestCaseV2(TestCase): onnx_config = onnx_config_class_constructor(model.config) if is_torch_available(): - from transformers.utils import torch_version + from transformers.utils import get_torch_version - if torch_version < onnx_config.torch_onnx_minimum_version: + if get_torch_version() < onnx_config.torch_onnx_minimum_version: pytest.skip( "Skipping due to incompatible PyTorch version. Minimum required is" - f" {onnx_config.torch_onnx_minimum_version}, got: {torch_version}" + f" {onnx_config.torch_onnx_minimum_version}, got: {get_torch_version()}" ) encoder_model = model.get_encoder()