fix tests
This commit is contained in:
@@ -23,16 +23,20 @@ from botocore.exceptions import ClientError
|
||||
import requests
|
||||
from tqdm import tqdm
|
||||
|
||||
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
try:
|
||||
import tensorflow as tf
|
||||
assert int(tf.__version__[0]) >= 2
|
||||
_tf_available = True # pylint: disable=invalid-name
|
||||
logger.info("TensorFlow version {} available.".format(tf.__version__))
|
||||
except (ImportError, AssertionError):
|
||||
_tf_available = False # pylint: disable=invalid-name
|
||||
|
||||
try:
|
||||
import torch
|
||||
_torch_available = True # pylint: disable=invalid-name
|
||||
logger.info("PyTorch version {} available.".format(torch.__version__))
|
||||
except ImportError:
|
||||
_torch_available = False # pylint: disable=invalid-name
|
||||
|
||||
@@ -67,8 +71,6 @@ TF2_WEIGHTS_NAME = 'tf_model.h5'
|
||||
TF_WEIGHTS_NAME = 'model.ckpt'
|
||||
CONFIG_NAME = "config.json"
|
||||
|
||||
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
def is_torch_available():
|
||||
return _torch_available
|
||||
|
||||
|
||||
Reference in New Issue
Block a user