let encode accept tensor inputs
This commit is contained in:
@@ -23,6 +23,20 @@ from botocore.exceptions import ClientError
|
||||
import requests
|
||||
from tqdm import tqdm
|
||||
|
||||
try:
|
||||
import tensorflow as tf
|
||||
assert int(tf.__version__[0]) >= 2
|
||||
_tf_available = True # pylint: disable=invalid-name
|
||||
except (ImportError, AssertionError):
|
||||
_tf_available = False # pylint: disable=invalid-name
|
||||
|
||||
try:
|
||||
import torch
|
||||
_torch_available = True # pylint: disable=invalid-name
|
||||
except ImportError:
|
||||
_torch_available = False # pylint: disable=invalid-name
|
||||
|
||||
|
||||
try:
|
||||
from torch.hub import _get_torch_home
|
||||
torch_cache_home = _get_torch_home()
|
||||
@@ -55,6 +69,12 @@ CONFIG_NAME = "config.json"
|
||||
|
||||
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
def is_torch_available():
|
||||
return _torch_available
|
||||
|
||||
def is_tf_available():
|
||||
return _tf_available
|
||||
|
||||
if not six.PY2:
|
||||
def add_start_docstrings(*docstr):
|
||||
def docstring_decorator(fn):
|
||||
|
||||
Reference in New Issue
Block a user