Fast transformers import part 1 (#9441)
* Don't import libs to check they are available * Don't import integrations at init * Add importlib_metdata to deps * Remove old vars references * Avoid syntax error * Adapt testing utils * Try to appease torchhub * Add dependency * Remove more private variables * Fix typo * Another typo * Refine the tf availability test
This commit is contained in:
@@ -18,6 +18,7 @@ https://github.com/allenai/allennlp.
|
||||
|
||||
import copy
|
||||
import fnmatch
|
||||
import importlib.util
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
@@ -37,8 +38,10 @@ from urllib.parse import urlparse
|
||||
from zipfile import ZipFile, is_zipfile
|
||||
|
||||
import numpy as np
|
||||
from packaging import version
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
import importlib_metadata
|
||||
import requests
|
||||
from filelock import FileLock
|
||||
|
||||
@@ -52,195 +55,88 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
ENV_VARS_TRUE_VALUES = {"1", "ON", "YES"}
|
||||
ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"})
|
||||
|
||||
USE_TF = os.environ.get("USE_TF", "AUTO").upper()
|
||||
USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
|
||||
USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper()
|
||||
|
||||
if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES:
|
||||
_torch_available = importlib.util.find_spec("torch") is not None
|
||||
if _torch_available:
|
||||
try:
|
||||
_torch_version = importlib_metadata.version("torch")
|
||||
logger.info(f"PyTorch version {_torch_version} available.")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_torch_available = False
|
||||
else:
|
||||
logger.info("Disabling PyTorch because USE_TF is set")
|
||||
_torch_available = False
|
||||
|
||||
|
||||
if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES:
|
||||
_tf_available = importlib.util.find_spec("tensorflow") is not None
|
||||
if _tf_available:
|
||||
# For the metadata, we have to look for both tensorflow and tensorflow-cpu
|
||||
try:
|
||||
_tf_version = importlib_metadata.version("tensorflow")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
try:
|
||||
_tf_version = importlib_metadata.version("tensorflow-cpu")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_tf_version = None
|
||||
_tf_available = False
|
||||
if _tf_available:
|
||||
if version.parse(_tf_version) < version.parse("2"):
|
||||
logger.info(f"TensorFlow found but with version {_tf_version}. Transformers requires version 2 minimum.")
|
||||
_tf_available = False
|
||||
else:
|
||||
logger.info(f"TensorFlow version {_tf_version} available.")
|
||||
else:
|
||||
logger.info("Disabling Tensorflow because USE_TORCH is set")
|
||||
_tf_available = False
|
||||
|
||||
|
||||
if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES:
|
||||
_flax_available = importlib.util.find_spec("jax") is not None and importlib.util.find_spec("flax") is not None
|
||||
if _flax_available:
|
||||
try:
|
||||
_jax_version = importlib_metadata.version("jax")
|
||||
_flax_version = importlib_metadata.version("flax")
|
||||
logger.info(f"JAX version {_jax_version}, Flax version {_flax_version} available.")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_flax_available = False
|
||||
else:
|
||||
_flax_available = False
|
||||
|
||||
|
||||
_datasets_available = importlib.util.find_spec("datasets") is not None
|
||||
try:
|
||||
USE_TF = os.environ.get("USE_TF", "AUTO").upper()
|
||||
USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
|
||||
if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES:
|
||||
import torch
|
||||
|
||||
_torch_available = True # pylint: disable=invalid-name
|
||||
logger.info("PyTorch version {} available.".format(torch.__version__))
|
||||
else:
|
||||
logger.info("Disabling PyTorch because USE_TF is set")
|
||||
_torch_available = False
|
||||
except ImportError:
|
||||
_torch_available = False # pylint: disable=invalid-name
|
||||
|
||||
try:
|
||||
USE_TF = os.environ.get("USE_TF", "AUTO").upper()
|
||||
USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
|
||||
|
||||
if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES:
|
||||
import tensorflow as tf
|
||||
|
||||
assert hasattr(tf, "__version__") and int(tf.__version__[0]) >= 2
|
||||
_tf_available = True # pylint: disable=invalid-name
|
||||
logger.info("TensorFlow version {} available.".format(tf.__version__))
|
||||
else:
|
||||
logger.info("Disabling Tensorflow because USE_TORCH is set")
|
||||
_tf_available = False
|
||||
except (ImportError, AssertionError):
|
||||
_tf_available = False # pylint: disable=invalid-name
|
||||
|
||||
|
||||
try:
|
||||
USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper()
|
||||
|
||||
if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES:
|
||||
import flax
|
||||
import jax
|
||||
|
||||
logger.info("JAX version {}, Flax: available".format(jax.__version__))
|
||||
logger.info("Flax available: {}".format(flax))
|
||||
_flax_available = True
|
||||
else:
|
||||
_flax_available = False
|
||||
except ImportError:
|
||||
_flax_available = False # pylint: disable=invalid-name
|
||||
|
||||
|
||||
try:
|
||||
import datasets # noqa: F401
|
||||
|
||||
# Check we're not importing a "datasets" directory somewhere
|
||||
_datasets_available = hasattr(datasets, "__version__") and hasattr(datasets, "load_dataset")
|
||||
if _datasets_available:
|
||||
logger.debug(f"Successfully imported datasets version {datasets.__version__}")
|
||||
else:
|
||||
logger.debug("Imported a datasets object but this doesn't seem to be the 🤗 datasets library.")
|
||||
|
||||
except ImportError:
|
||||
# Check we're not importing a "datasets" directory somewhere but the actual library by trying to grab the version
|
||||
# AND checking it has an author field in the metadata that is HuggingFace.
|
||||
_ = importlib_metadata.version("datasets")
|
||||
_datasets_metadata = importlib_metadata.metadata("datasets")
|
||||
if _datasets_metadata.get("author", "") != "HuggingFace Inc.":
|
||||
_datasets_available = False
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_datasets_available = False
|
||||
|
||||
|
||||
_faiss_available = importlib.util.find_spec("faiss") is not None
|
||||
try:
|
||||
from torch.hub import _get_torch_home
|
||||
|
||||
torch_cache_home = _get_torch_home()
|
||||
except ImportError:
|
||||
torch_cache_home = os.path.expanduser(
|
||||
os.getenv("TORCH_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "torch"))
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
import torch_xla.core.xla_model as xm # noqa: F401
|
||||
|
||||
if _torch_available:
|
||||
_torch_tpu_available = True # pylint: disable=
|
||||
else:
|
||||
_torch_tpu_available = False
|
||||
except ImportError:
|
||||
_torch_tpu_available = False
|
||||
|
||||
|
||||
try:
|
||||
import psutil # noqa: F401
|
||||
|
||||
_psutil_available = True
|
||||
|
||||
except ImportError:
|
||||
_psutil_available = False
|
||||
|
||||
|
||||
try:
|
||||
import py3nvml # noqa: F401
|
||||
|
||||
_py3nvml_available = True
|
||||
|
||||
except ImportError:
|
||||
_py3nvml_available = False
|
||||
|
||||
|
||||
try:
|
||||
from apex import amp # noqa: F401
|
||||
|
||||
_has_apex = True
|
||||
except ImportError:
|
||||
_has_apex = False
|
||||
|
||||
|
||||
try:
|
||||
import faiss # noqa: F401
|
||||
|
||||
_faiss_available = True
|
||||
logger.debug(f"Successfully imported faiss version {faiss.__version__}")
|
||||
except ImportError:
|
||||
_faiss_version = importlib_metadata.version("faiss")
|
||||
logger.debug(f"Successfully imported faiss version {_faiss_version}")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_faiss_available = False
|
||||
|
||||
|
||||
_scatter_available = importlib.util.find_spec("torch_scatter") is not None
|
||||
try:
|
||||
import sklearn.metrics # noqa: F401
|
||||
|
||||
import scipy.stats # noqa: F401
|
||||
|
||||
_has_sklearn = True
|
||||
except (AttributeError, ImportError):
|
||||
_has_sklearn = False
|
||||
|
||||
try:
|
||||
# Test copied from tqdm.autonotebook: https://github.com/tqdm/tqdm/blob/master/tqdm/autonotebook.py
|
||||
get_ipython = sys.modules["IPython"].get_ipython
|
||||
if "IPKernelApp" not in get_ipython().config:
|
||||
raise ImportError("console")
|
||||
if "VSCODE_PID" in os.environ:
|
||||
raise ImportError("vscode")
|
||||
|
||||
import IPython # noqa: F401
|
||||
|
||||
_in_notebook = True
|
||||
except (AttributeError, ImportError, KeyError):
|
||||
_in_notebook = False
|
||||
|
||||
|
||||
try:
|
||||
import sentencepiece # noqa: F401
|
||||
|
||||
_sentencepiece_available = True
|
||||
|
||||
except ImportError:
|
||||
_sentencepiece_available = False
|
||||
|
||||
|
||||
try:
|
||||
import google.protobuf # noqa: F401
|
||||
|
||||
_protobuf_available = True
|
||||
|
||||
except ImportError:
|
||||
_protobuf_available = False
|
||||
|
||||
|
||||
try:
|
||||
import tokenizers # noqa: F401
|
||||
|
||||
_tokenizers_available = True
|
||||
|
||||
except ImportError:
|
||||
_tokenizers_available = False
|
||||
|
||||
|
||||
try:
|
||||
import pandas # noqa: F401
|
||||
|
||||
_pandas_available = True
|
||||
|
||||
except ImportError:
|
||||
_pandas_available = False
|
||||
|
||||
|
||||
try:
|
||||
import torch_scatter
|
||||
|
||||
# Check we're not importing a "torch_scatter" directory somewhere
|
||||
_scatter_available = hasattr(torch_scatter, "__version__") and hasattr(torch_scatter, "scatter")
|
||||
if _scatter_available:
|
||||
logger.debug(f"Succesfully imported torch-scatter version {torch_scatter.__version__}")
|
||||
else:
|
||||
logger.debug("Imported a torch_scatter object but this doesn't seem to be the torch-scatter library.")
|
||||
|
||||
except ImportError:
|
||||
_scatter_version = importlib_metadata.version("torch_scatterr")
|
||||
logger.debug(f"Successfully imported torch-scatter version {_scatter_version}")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_scatter_available = False
|
||||
|
||||
|
||||
torch_cache_home = os.getenv("TORCH_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "torch"))
|
||||
old_default_cache_path = os.path.join(torch_cache_home, "transformers")
|
||||
# New default cache, shared with the Datasets library
|
||||
hf_cache_home = os.path.expanduser(
|
||||
@@ -308,7 +204,14 @@ def is_flax_available():
|
||||
|
||||
|
||||
def is_torch_tpu_available():
|
||||
return _torch_tpu_available
|
||||
if not _torch_available:
|
||||
return False
|
||||
# This test is probably enough, but just in case, we unpack a bit.
|
||||
if importlib.util.find_spec("torch_xla") is None:
|
||||
return False
|
||||
if importlib.util.find_spec("torch_xla.core") is None:
|
||||
return False
|
||||
return importlib.util.find_spec("torch_xla.core.xla_model") is not None
|
||||
|
||||
|
||||
def is_datasets_available():
|
||||
@@ -316,15 +219,15 @@ def is_datasets_available():
|
||||
|
||||
|
||||
def is_psutil_available():
|
||||
return _psutil_available
|
||||
return importlib.util.find_spec("psutil") is not None
|
||||
|
||||
|
||||
def is_py3nvml_available():
|
||||
return _py3nvml_available
|
||||
return importlib.util.find_spec("py3nvml") is not None
|
||||
|
||||
|
||||
def is_apex_available():
|
||||
return _has_apex
|
||||
return importlib.util.find_spec("apex") is not None
|
||||
|
||||
|
||||
def is_faiss_available():
|
||||
@@ -332,23 +235,39 @@ def is_faiss_available():
|
||||
|
||||
|
||||
def is_sklearn_available():
|
||||
return _has_sklearn
|
||||
if importlib.util.find_spec("sklearn") is None:
|
||||
return False
|
||||
if importlib.util.find_spec("scipy") is None:
|
||||
return False
|
||||
return importlib.util.find_spec("sklearn.metrics") and importlib.util.find_spec("scipy.stats")
|
||||
|
||||
|
||||
def is_sentencepiece_available():
|
||||
return _sentencepiece_available
|
||||
return importlib.util.find_spec("sentencepiece") is not None
|
||||
|
||||
|
||||
def is_protobuf_available():
|
||||
return _protobuf_available
|
||||
if importlib.util.find_spec("google") is None:
|
||||
return False
|
||||
return importlib.util.find_spec("google.protobuf") is not None
|
||||
|
||||
|
||||
def is_tokenizers_available():
|
||||
return _tokenizers_available
|
||||
return importlib.util.find_spec("tokenizers") is not None
|
||||
|
||||
|
||||
def is_in_notebook():
|
||||
return _in_notebook
|
||||
try:
|
||||
# Test adapted from tqdm.autonotebook: https://github.com/tqdm/tqdm/blob/master/tqdm/autonotebook.py
|
||||
get_ipython = sys.modules["IPython"].get_ipython
|
||||
if "IPKernelApp" not in get_ipython().config:
|
||||
raise ImportError("console")
|
||||
if "VSCODE_PID" in os.environ:
|
||||
raise ImportError("vscode")
|
||||
|
||||
return importlib.util.find_spec("IPython") is not None
|
||||
except (AttributeError, ImportError, KeyError):
|
||||
return False
|
||||
|
||||
|
||||
def is_scatter_available():
|
||||
@@ -356,7 +275,7 @@ def is_scatter_available():
|
||||
|
||||
|
||||
def is_pandas_available():
|
||||
return _pandas_available
|
||||
return importlib.util.find_spec("pandas") is not None
|
||||
|
||||
|
||||
def torch_only_method(fn):
|
||||
@@ -1167,9 +1086,9 @@ def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
|
||||
"""
|
||||
ua = "transformers/{}; python/{}".format(__version__, sys.version.split()[0])
|
||||
if is_torch_available():
|
||||
ua += "; torch/{}".format(torch.__version__)
|
||||
ua += f"; torch/{_torch_version}"
|
||||
if is_tf_available():
|
||||
ua += "; tensorflow/{}".format(tf.__version__)
|
||||
ua += f"; tensorflow/{_tf_version}"
|
||||
if isinstance(user_agent, dict):
|
||||
ua += "; " + "; ".join("{}/{}".format(k, v) for k, v in user_agent.items())
|
||||
elif isinstance(user_agent, str):
|
||||
|
||||
Reference in New Issue
Block a user