let encode accept tensor inputs

This commit is contained in:
thomwolf
2019-09-25 10:19:14 +02:00
parent e8e956dbb2
commit c4acc3a8e9
3 changed files with 48 additions and 17 deletions

View File

@@ -23,6 +23,20 @@ from botocore.exceptions import ClientError
import requests
from tqdm import tqdm
try:
import tensorflow as tf
assert int(tf.__version__[0]) >= 2
_tf_available = True # pylint: disable=invalid-name
except (ImportError, AssertionError):
_tf_available = False # pylint: disable=invalid-name
try:
import torch
_torch_available = True # pylint: disable=invalid-name
except ImportError:
_torch_available = False # pylint: disable=invalid-name
try:
from torch.hub import _get_torch_home
torch_cache_home = _get_torch_home()
@@ -55,6 +69,12 @@ CONFIG_NAME = "config.json"
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
def is_torch_available():
return _torch_available
def is_tf_available():
return _tf_available
if not six.PY2:
def add_start_docstrings(*docstr):
def docstring_decorator(fn):