Fast transformers import part 1 (#9441)

* Don't import libs to check they are available

* Don't import integrations at init

* Add importlib_metdata to deps

* Remove old vars references

* Avoid syntax error

* Adapt testing utils

* Try to appease torchhub

* Add dependency

* Remove more private variables

* Fix typo

* Another typo

* Refine the tf availability test
This commit is contained in:
Sylvain Gugger
2021-01-06 12:17:24 -05:00
committed by GitHub
parent c89f1bc92e
commit 0c96262f7d
13 changed files with 280 additions and 360 deletions

View File

@@ -18,6 +18,7 @@ https://github.com/allenai/allennlp.
import copy
import fnmatch
import importlib.util
import io
import json
import os
@@ -37,8 +38,10 @@ from urllib.parse import urlparse
from zipfile import ZipFile, is_zipfile
import numpy as np
from packaging import version
from tqdm.auto import tqdm
import importlib_metadata
import requests
from filelock import FileLock
@@ -52,195 +55,88 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
ENV_VARS_TRUE_VALUES = {"1", "ON", "YES"}
ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"})
USE_TF = os.environ.get("USE_TF", "AUTO").upper()
USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper()
if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES:
_torch_available = importlib.util.find_spec("torch") is not None
if _torch_available:
try:
_torch_version = importlib_metadata.version("torch")
logger.info(f"PyTorch version {_torch_version} available.")
except importlib_metadata.PackageNotFoundError:
_torch_available = False
else:
logger.info("Disabling PyTorch because USE_TF is set")
_torch_available = False
if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES:
_tf_available = importlib.util.find_spec("tensorflow") is not None
if _tf_available:
# For the metadata, we have to look for both tensorflow and tensorflow-cpu
try:
_tf_version = importlib_metadata.version("tensorflow")
except importlib_metadata.PackageNotFoundError:
try:
_tf_version = importlib_metadata.version("tensorflow-cpu")
except importlib_metadata.PackageNotFoundError:
_tf_version = None
_tf_available = False
if _tf_available:
if version.parse(_tf_version) < version.parse("2"):
logger.info(f"TensorFlow found but with version {_tf_version}. Transformers requires version 2 minimum.")
_tf_available = False
else:
logger.info(f"TensorFlow version {_tf_version} available.")
else:
logger.info("Disabling Tensorflow because USE_TORCH is set")
_tf_available = False
if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES:
_flax_available = importlib.util.find_spec("jax") is not None and importlib.util.find_spec("flax") is not None
if _flax_available:
try:
_jax_version = importlib_metadata.version("jax")
_flax_version = importlib_metadata.version("flax")
logger.info(f"JAX version {_jax_version}, Flax version {_flax_version} available.")
except importlib_metadata.PackageNotFoundError:
_flax_available = False
else:
_flax_available = False
_datasets_available = importlib.util.find_spec("datasets") is not None
try:
USE_TF = os.environ.get("USE_TF", "AUTO").upper()
USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES:
import torch
_torch_available = True # pylint: disable=invalid-name
logger.info("PyTorch version {} available.".format(torch.__version__))
else:
logger.info("Disabling PyTorch because USE_TF is set")
_torch_available = False
except ImportError:
_torch_available = False # pylint: disable=invalid-name
try:
USE_TF = os.environ.get("USE_TF", "AUTO").upper()
USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES:
import tensorflow as tf
assert hasattr(tf, "__version__") and int(tf.__version__[0]) >= 2
_tf_available = True # pylint: disable=invalid-name
logger.info("TensorFlow version {} available.".format(tf.__version__))
else:
logger.info("Disabling Tensorflow because USE_TORCH is set")
_tf_available = False
except (ImportError, AssertionError):
_tf_available = False # pylint: disable=invalid-name
try:
USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper()
if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES:
import flax
import jax
logger.info("JAX version {}, Flax: available".format(jax.__version__))
logger.info("Flax available: {}".format(flax))
_flax_available = True
else:
_flax_available = False
except ImportError:
_flax_available = False # pylint: disable=invalid-name
try:
import datasets # noqa: F401
# Check we're not importing a "datasets" directory somewhere
_datasets_available = hasattr(datasets, "__version__") and hasattr(datasets, "load_dataset")
if _datasets_available:
logger.debug(f"Successfully imported datasets version {datasets.__version__}")
else:
logger.debug("Imported a datasets object but this doesn't seem to be the 🤗 datasets library.")
except ImportError:
# Check we're not importing a "datasets" directory somewhere but the actual library by trying to grab the version
# AND checking it has an author field in the metadata that is HuggingFace.
_ = importlib_metadata.version("datasets")
_datasets_metadata = importlib_metadata.metadata("datasets")
if _datasets_metadata.get("author", "") != "HuggingFace Inc.":
_datasets_available = False
except importlib_metadata.PackageNotFoundError:
_datasets_available = False
_faiss_available = importlib.util.find_spec("faiss") is not None
try:
from torch.hub import _get_torch_home
torch_cache_home = _get_torch_home()
except ImportError:
torch_cache_home = os.path.expanduser(
os.getenv("TORCH_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "torch"))
)
try:
import torch_xla.core.xla_model as xm # noqa: F401
if _torch_available:
_torch_tpu_available = True # pylint: disable=
else:
_torch_tpu_available = False
except ImportError:
_torch_tpu_available = False
try:
import psutil # noqa: F401
_psutil_available = True
except ImportError:
_psutil_available = False
try:
import py3nvml # noqa: F401
_py3nvml_available = True
except ImportError:
_py3nvml_available = False
try:
from apex import amp # noqa: F401
_has_apex = True
except ImportError:
_has_apex = False
try:
import faiss # noqa: F401
_faiss_available = True
logger.debug(f"Successfully imported faiss version {faiss.__version__}")
except ImportError:
_faiss_version = importlib_metadata.version("faiss")
logger.debug(f"Successfully imported faiss version {_faiss_version}")
except importlib_metadata.PackageNotFoundError:
_faiss_available = False
_scatter_available = importlib.util.find_spec("torch_scatter") is not None
try:
import sklearn.metrics # noqa: F401
import scipy.stats # noqa: F401
_has_sklearn = True
except (AttributeError, ImportError):
_has_sklearn = False
try:
# Test copied from tqdm.autonotebook: https://github.com/tqdm/tqdm/blob/master/tqdm/autonotebook.py
get_ipython = sys.modules["IPython"].get_ipython
if "IPKernelApp" not in get_ipython().config:
raise ImportError("console")
if "VSCODE_PID" in os.environ:
raise ImportError("vscode")
import IPython # noqa: F401
_in_notebook = True
except (AttributeError, ImportError, KeyError):
_in_notebook = False
try:
import sentencepiece # noqa: F401
_sentencepiece_available = True
except ImportError:
_sentencepiece_available = False
try:
import google.protobuf # noqa: F401
_protobuf_available = True
except ImportError:
_protobuf_available = False
try:
import tokenizers # noqa: F401
_tokenizers_available = True
except ImportError:
_tokenizers_available = False
try:
import pandas # noqa: F401
_pandas_available = True
except ImportError:
_pandas_available = False
try:
import torch_scatter
# Check we're not importing a "torch_scatter" directory somewhere
_scatter_available = hasattr(torch_scatter, "__version__") and hasattr(torch_scatter, "scatter")
if _scatter_available:
logger.debug(f"Succesfully imported torch-scatter version {torch_scatter.__version__}")
else:
logger.debug("Imported a torch_scatter object but this doesn't seem to be the torch-scatter library.")
except ImportError:
_scatter_version = importlib_metadata.version("torch_scatterr")
logger.debug(f"Successfully imported torch-scatter version {_scatter_version}")
except importlib_metadata.PackageNotFoundError:
_scatter_available = False
torch_cache_home = os.getenv("TORCH_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "torch"))
old_default_cache_path = os.path.join(torch_cache_home, "transformers")
# New default cache, shared with the Datasets library
hf_cache_home = os.path.expanduser(
@@ -308,7 +204,14 @@ def is_flax_available():
def is_torch_tpu_available():
return _torch_tpu_available
if not _torch_available:
return False
# This test is probably enough, but just in case, we unpack a bit.
if importlib.util.find_spec("torch_xla") is None:
return False
if importlib.util.find_spec("torch_xla.core") is None:
return False
return importlib.util.find_spec("torch_xla.core.xla_model") is not None
def is_datasets_available():
@@ -316,15 +219,15 @@ def is_datasets_available():
def is_psutil_available():
return _psutil_available
return importlib.util.find_spec("psutil") is not None
def is_py3nvml_available():
return _py3nvml_available
return importlib.util.find_spec("py3nvml") is not None
def is_apex_available():
return _has_apex
return importlib.util.find_spec("apex") is not None
def is_faiss_available():
@@ -332,23 +235,39 @@ def is_faiss_available():
def is_sklearn_available():
return _has_sklearn
if importlib.util.find_spec("sklearn") is None:
return False
if importlib.util.find_spec("scipy") is None:
return False
return importlib.util.find_spec("sklearn.metrics") and importlib.util.find_spec("scipy.stats")
def is_sentencepiece_available():
return _sentencepiece_available
return importlib.util.find_spec("sentencepiece") is not None
def is_protobuf_available():
return _protobuf_available
if importlib.util.find_spec("google") is None:
return False
return importlib.util.find_spec("google.protobuf") is not None
def is_tokenizers_available():
return _tokenizers_available
return importlib.util.find_spec("tokenizers") is not None
def is_in_notebook():
return _in_notebook
try:
# Test adapted from tqdm.autonotebook: https://github.com/tqdm/tqdm/blob/master/tqdm/autonotebook.py
get_ipython = sys.modules["IPython"].get_ipython
if "IPKernelApp" not in get_ipython().config:
raise ImportError("console")
if "VSCODE_PID" in os.environ:
raise ImportError("vscode")
return importlib.util.find_spec("IPython") is not None
except (AttributeError, ImportError, KeyError):
return False
def is_scatter_available():
@@ -356,7 +275,7 @@ def is_scatter_available():
def is_pandas_available():
return _pandas_available
return importlib.util.find_spec("pandas") is not None
def torch_only_method(fn):
@@ -1167,9 +1086,9 @@ def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
"""
ua = "transformers/{}; python/{}".format(__version__, sys.version.split()[0])
if is_torch_available():
ua += "; torch/{}".format(torch.__version__)
ua += f"; torch/{_torch_version}"
if is_tf_available():
ua += "; tensorflow/{}".format(tf.__version__)
ua += f"; tensorflow/{_tf_version}"
if isinstance(user_agent, dict):
ua += "; " + "; ".join("{}/{}".format(k, v) for k, v in user_agent.items())
elif isinstance(user_agent, str):