Fast transformers import part 1 (#9441)

* Don't import libs to check they are available

* Don't import integrations at init

* Add importlib_metdata to deps

* Remove old vars references

* Avoid syntax error

* Adapt testing utils

* Try to appease torchhub

* Add dependency

* Remove more private variables

* Fix typo

* Another typo

* Refine the tf availability test
This commit is contained in:
Sylvain Gugger
2021-01-06 12:17:24 -05:00
committed by GitHub
parent c89f1bc92e
commit 0c96262f7d
13 changed files with 280 additions and 360 deletions

View File

@@ -25,18 +25,18 @@ from io import StringIO
from pathlib import Path
from .file_utils import (
_datasets_available,
_faiss_available,
_flax_available,
_pandas_available,
_scatter_available,
_sentencepiece_available,
_tf_available,
_tokenizers_available,
_torch_available,
_torch_tpu_available,
is_datasets_available,
is_faiss_available,
is_flax_available,
is_pandas_available,
is_scatter_available,
is_sentencepiece_available,
is_tf_available,
is_tokenizers_available,
is_torch_available,
is_torch_tpu_available,
)
from .integrations import _has_optuna, _has_ray
from .integrations import is_optuna_available, is_ray_available
SMALL_MODEL_IDENTIFIER = "julien-c/bert-xsmall-dummy"
@@ -90,7 +90,7 @@ def is_pt_tf_cross_test(test_case):
to a truthy value and selecting the is_pt_tf_cross_test pytest mark.
"""
if not _run_pt_tf_cross_tests or not _torch_available or not _tf_available:
if not _run_pt_tf_cross_tests or not is_torch_available() or not is_tf_available():
return unittest.skip("test is PT+TF test")(test_case)
else:
try:
@@ -166,7 +166,7 @@ def require_torch(test_case):
These tests are skipped when PyTorch isn't installed.
"""
if not _torch_available:
if not is_torch_available():
return unittest.skip("test requires PyTorch")(test_case)
else:
return test_case
@@ -179,7 +179,7 @@ def require_torch_scatter(test_case):
These tests are skipped when PyTorch scatter isn't installed.
"""
if not _scatter_available:
if not is_scatter_available():
return unittest.skip("test requires PyTorch scatter")(test_case)
else:
return test_case
@@ -192,7 +192,7 @@ def require_tf(test_case):
These tests are skipped when TensorFlow isn't installed.
"""
if not _tf_available:
if not is_tf_available():
return unittest.skip("test requires TensorFlow")(test_case)
else:
return test_case
@@ -205,7 +205,7 @@ def require_flax(test_case):
These tests are skipped when one / both are not installed
"""
if not _flax_available:
if not is_flax_available():
test_case = unittest.skip("test requires JAX & Flax")(test_case)
return test_case
@@ -217,7 +217,7 @@ def require_sentencepiece(test_case):
These tests are skipped when SentencePiece isn't installed.
"""
if not _sentencepiece_available:
if not is_sentencepiece_available():
return unittest.skip("test requires SentencePiece")(test_case)
else:
return test_case
@@ -230,7 +230,7 @@ def require_tokenizers(test_case):
These tests are skipped when 🤗 Tokenizers isn't installed.
"""
if not _tokenizers_available:
if not is_tokenizers_available():
return unittest.skip("test requires tokenizers")(test_case)
else:
return test_case
@@ -240,7 +240,7 @@ def require_pandas(test_case):
"""
Decorator marking a test that requires pandas. These tests are skipped when pandas isn't installed.
"""
if not _pandas_available:
if not is_pandas_available():
return unittest.skip("test requires pandas")(test_case)
else:
return test_case
@@ -251,7 +251,7 @@ def require_scatter(test_case):
Decorator marking a test that requires PyTorch Scatter. These tests are skipped when PyTorch Scatter isn't
installed.
"""
if not _scatter_available:
if not is_scatter_available():
return unittest.skip("test requires PyTorch Scatter")(test_case)
else:
return test_case
@@ -265,7 +265,7 @@ def require_torch_multi_gpu(test_case):
To run *only* the multi_gpu tests, assuming all test names contain multi_gpu: $ pytest -sv ./tests -k "multi_gpu"
"""
if not _torch_available:
if not is_torch_available():
return unittest.skip("test requires PyTorch")(test_case)
import torch
@@ -280,7 +280,7 @@ def require_torch_non_multi_gpu(test_case):
"""
Decorator marking a test that requires 0 or 1 GPU setup (in PyTorch).
"""
if not _torch_available:
if not is_torch_available():
return unittest.skip("test requires PyTorch")(test_case)
import torch
@@ -301,13 +301,13 @@ def require_torch_tpu(test_case):
"""
Decorator marking a test that requires a TPU (in PyTorch).
"""
if not _torch_tpu_available:
if not is_torch_tpu_available():
return unittest.skip("test requires PyTorch TPU")
else:
return test_case
if _torch_available:
if is_torch_available():
# Set env var CUDA_VISIBLE_DEVICES="" to force cpu-mode
import torch
@@ -327,7 +327,7 @@ def require_torch_gpu(test_case):
def require_datasets(test_case):
"""Decorator marking a test that requires datasets."""
if not _datasets_available:
if not is_datasets_available():
return unittest.skip("test requires `datasets`")(test_case)
else:
return test_case
@@ -335,7 +335,7 @@ def require_datasets(test_case):
def require_faiss(test_case):
"""Decorator marking a test that requires faiss."""
if not _faiss_available:
if not is_faiss_available():
return unittest.skip("test requires `faiss`")(test_case)
else:
return test_case
@@ -348,7 +348,7 @@ def require_optuna(test_case):
These tests are skipped when optuna isn't installed.
"""
if not _has_optuna:
if not is_optuna_available():
return unittest.skip("test requires optuna")(test_case)
else:
return test_case
@@ -361,7 +361,7 @@ def require_ray(test_case):
These tests are skipped when Ray/tune isn't installed.
"""
if not _has_ray:
if not is_ray_available():
return unittest.skip("test requires Ray/tune")(test_case)
else:
return test_case
@@ -371,11 +371,11 @@ def get_gpu_count():
"""
Return the number of available gpus (regardless of whether torch or tf is used)
"""
if _torch_available:
if is_torch_available():
import torch
return torch.cuda.device_count()
elif _tf_available:
elif is_tf_available():
import tensorflow as tf
return len(tf.config.list_physical_devices("GPU"))