Fast transformers import part 1 (#9441)
* Don't import libs to check they are available * Don't import integrations at init * Add importlib_metdata to deps * Remove old vars references * Avoid syntax error * Adapt testing utils * Try to appease torchhub * Add dependency * Remove more private variables * Fix typo * Another typo * Refine the tf availability test
This commit is contained in:
@@ -25,18 +25,18 @@ from io import StringIO
|
||||
from pathlib import Path
|
||||
|
||||
from .file_utils import (
|
||||
_datasets_available,
|
||||
_faiss_available,
|
||||
_flax_available,
|
||||
_pandas_available,
|
||||
_scatter_available,
|
||||
_sentencepiece_available,
|
||||
_tf_available,
|
||||
_tokenizers_available,
|
||||
_torch_available,
|
||||
_torch_tpu_available,
|
||||
is_datasets_available,
|
||||
is_faiss_available,
|
||||
is_flax_available,
|
||||
is_pandas_available,
|
||||
is_scatter_available,
|
||||
is_sentencepiece_available,
|
||||
is_tf_available,
|
||||
is_tokenizers_available,
|
||||
is_torch_available,
|
||||
is_torch_tpu_available,
|
||||
)
|
||||
from .integrations import _has_optuna, _has_ray
|
||||
from .integrations import is_optuna_available, is_ray_available
|
||||
|
||||
|
||||
SMALL_MODEL_IDENTIFIER = "julien-c/bert-xsmall-dummy"
|
||||
@@ -90,7 +90,7 @@ def is_pt_tf_cross_test(test_case):
|
||||
to a truthy value and selecting the is_pt_tf_cross_test pytest mark.
|
||||
|
||||
"""
|
||||
if not _run_pt_tf_cross_tests or not _torch_available or not _tf_available:
|
||||
if not _run_pt_tf_cross_tests or not is_torch_available() or not is_tf_available():
|
||||
return unittest.skip("test is PT+TF test")(test_case)
|
||||
else:
|
||||
try:
|
||||
@@ -166,7 +166,7 @@ def require_torch(test_case):
|
||||
These tests are skipped when PyTorch isn't installed.
|
||||
|
||||
"""
|
||||
if not _torch_available:
|
||||
if not is_torch_available():
|
||||
return unittest.skip("test requires PyTorch")(test_case)
|
||||
else:
|
||||
return test_case
|
||||
@@ -179,7 +179,7 @@ def require_torch_scatter(test_case):
|
||||
These tests are skipped when PyTorch scatter isn't installed.
|
||||
|
||||
"""
|
||||
if not _scatter_available:
|
||||
if not is_scatter_available():
|
||||
return unittest.skip("test requires PyTorch scatter")(test_case)
|
||||
else:
|
||||
return test_case
|
||||
@@ -192,7 +192,7 @@ def require_tf(test_case):
|
||||
These tests are skipped when TensorFlow isn't installed.
|
||||
|
||||
"""
|
||||
if not _tf_available:
|
||||
if not is_tf_available():
|
||||
return unittest.skip("test requires TensorFlow")(test_case)
|
||||
else:
|
||||
return test_case
|
||||
@@ -205,7 +205,7 @@ def require_flax(test_case):
|
||||
These tests are skipped when one / both are not installed
|
||||
|
||||
"""
|
||||
if not _flax_available:
|
||||
if not is_flax_available():
|
||||
test_case = unittest.skip("test requires JAX & Flax")(test_case)
|
||||
return test_case
|
||||
|
||||
@@ -217,7 +217,7 @@ def require_sentencepiece(test_case):
|
||||
These tests are skipped when SentencePiece isn't installed.
|
||||
|
||||
"""
|
||||
if not _sentencepiece_available:
|
||||
if not is_sentencepiece_available():
|
||||
return unittest.skip("test requires SentencePiece")(test_case)
|
||||
else:
|
||||
return test_case
|
||||
@@ -230,7 +230,7 @@ def require_tokenizers(test_case):
|
||||
These tests are skipped when 🤗 Tokenizers isn't installed.
|
||||
|
||||
"""
|
||||
if not _tokenizers_available:
|
||||
if not is_tokenizers_available():
|
||||
return unittest.skip("test requires tokenizers")(test_case)
|
||||
else:
|
||||
return test_case
|
||||
@@ -240,7 +240,7 @@ def require_pandas(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires pandas. These tests are skipped when pandas isn't installed.
|
||||
"""
|
||||
if not _pandas_available:
|
||||
if not is_pandas_available():
|
||||
return unittest.skip("test requires pandas")(test_case)
|
||||
else:
|
||||
return test_case
|
||||
@@ -251,7 +251,7 @@ def require_scatter(test_case):
|
||||
Decorator marking a test that requires PyTorch Scatter. These tests are skipped when PyTorch Scatter isn't
|
||||
installed.
|
||||
"""
|
||||
if not _scatter_available:
|
||||
if not is_scatter_available():
|
||||
return unittest.skip("test requires PyTorch Scatter")(test_case)
|
||||
else:
|
||||
return test_case
|
||||
@@ -265,7 +265,7 @@ def require_torch_multi_gpu(test_case):
|
||||
|
||||
To run *only* the multi_gpu tests, assuming all test names contain multi_gpu: $ pytest -sv ./tests -k "multi_gpu"
|
||||
"""
|
||||
if not _torch_available:
|
||||
if not is_torch_available():
|
||||
return unittest.skip("test requires PyTorch")(test_case)
|
||||
|
||||
import torch
|
||||
@@ -280,7 +280,7 @@ def require_torch_non_multi_gpu(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires 0 or 1 GPU setup (in PyTorch).
|
||||
"""
|
||||
if not _torch_available:
|
||||
if not is_torch_available():
|
||||
return unittest.skip("test requires PyTorch")(test_case)
|
||||
|
||||
import torch
|
||||
@@ -301,13 +301,13 @@ def require_torch_tpu(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires a TPU (in PyTorch).
|
||||
"""
|
||||
if not _torch_tpu_available:
|
||||
if not is_torch_tpu_available():
|
||||
return unittest.skip("test requires PyTorch TPU")
|
||||
else:
|
||||
return test_case
|
||||
|
||||
|
||||
if _torch_available:
|
||||
if is_torch_available():
|
||||
# Set env var CUDA_VISIBLE_DEVICES="" to force cpu-mode
|
||||
import torch
|
||||
|
||||
@@ -327,7 +327,7 @@ def require_torch_gpu(test_case):
|
||||
def require_datasets(test_case):
|
||||
"""Decorator marking a test that requires datasets."""
|
||||
|
||||
if not _datasets_available:
|
||||
if not is_datasets_available():
|
||||
return unittest.skip("test requires `datasets`")(test_case)
|
||||
else:
|
||||
return test_case
|
||||
@@ -335,7 +335,7 @@ def require_datasets(test_case):
|
||||
|
||||
def require_faiss(test_case):
|
||||
"""Decorator marking a test that requires faiss."""
|
||||
if not _faiss_available:
|
||||
if not is_faiss_available():
|
||||
return unittest.skip("test requires `faiss`")(test_case)
|
||||
else:
|
||||
return test_case
|
||||
@@ -348,7 +348,7 @@ def require_optuna(test_case):
|
||||
These tests are skipped when optuna isn't installed.
|
||||
|
||||
"""
|
||||
if not _has_optuna:
|
||||
if not is_optuna_available():
|
||||
return unittest.skip("test requires optuna")(test_case)
|
||||
else:
|
||||
return test_case
|
||||
@@ -361,7 +361,7 @@ def require_ray(test_case):
|
||||
These tests are skipped when Ray/tune isn't installed.
|
||||
|
||||
"""
|
||||
if not _has_ray:
|
||||
if not is_ray_available():
|
||||
return unittest.skip("test requires Ray/tune")(test_case)
|
||||
else:
|
||||
return test_case
|
||||
@@ -371,11 +371,11 @@ def get_gpu_count():
|
||||
"""
|
||||
Return the number of available gpus (regardless of whether torch or tf is used)
|
||||
"""
|
||||
if _torch_available:
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
return torch.cuda.device_count()
|
||||
elif _tf_available:
|
||||
elif is_tf_available():
|
||||
import tensorflow as tf
|
||||
|
||||
return len(tf.config.list_physical_devices("GPU"))
|
||||
|
||||
Reference in New Issue
Block a user