let encode accept tensor inputs
This commit is contained in:
@@ -163,10 +163,5 @@ if _tf_available and _torch_available:
|
|||||||
# Files and general utilities
|
# Files and general utilities
|
||||||
from .file_utils import (PYTORCH_TRANSFORMERS_CACHE, PYTORCH_PRETRAINED_BERT_CACHE,
|
from .file_utils import (PYTORCH_TRANSFORMERS_CACHE, PYTORCH_PRETRAINED_BERT_CACHE,
|
||||||
cached_path, add_start_docstrings, add_end_docstrings,
|
cached_path, add_start_docstrings, add_end_docstrings,
|
||||||
WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME, CONFIG_NAME)
|
WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME, CONFIG_NAME,
|
||||||
|
is_tf_available, is_torch_available)
|
||||||
def is_torch_available():
|
|
||||||
return _torch_available
|
|
||||||
|
|
||||||
def is_tf_available():
|
|
||||||
return _tf_available
|
|
||||||
@@ -23,6 +23,20 @@ from botocore.exceptions import ClientError
|
|||||||
import requests
|
import requests
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
try:
|
||||||
|
import tensorflow as tf
|
||||||
|
assert int(tf.__version__[0]) >= 2
|
||||||
|
_tf_available = True # pylint: disable=invalid-name
|
||||||
|
except (ImportError, AssertionError):
|
||||||
|
_tf_available = False # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
try:
|
||||||
|
import torch
|
||||||
|
_torch_available = True # pylint: disable=invalid-name
|
||||||
|
except ImportError:
|
||||||
|
_torch_available = False # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from torch.hub import _get_torch_home
|
from torch.hub import _get_torch_home
|
||||||
torch_cache_home = _get_torch_home()
|
torch_cache_home = _get_torch_home()
|
||||||
@@ -55,6 +69,12 @@ CONFIG_NAME = "config.json"
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
def is_torch_available():
|
||||||
|
return _torch_available
|
||||||
|
|
||||||
|
def is_tf_available():
|
||||||
|
return _tf_available
|
||||||
|
|
||||||
if not six.PY2:
|
if not six.PY2:
|
||||||
def add_start_docstrings(*docstr):
|
def add_start_docstrings(*docstr):
|
||||||
def docstring_decorator(fn):
|
def docstring_decorator(fn):
|
||||||
|
|||||||
@@ -23,7 +23,10 @@ import six
|
|||||||
import copy
|
import copy
|
||||||
from io import open
|
from io import open
|
||||||
|
|
||||||
from .file_utils import cached_path
|
from .file_utils import cached_path, is_tf_available
|
||||||
|
|
||||||
|
if is_tf_available():
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -686,19 +689,32 @@ class PreTrainedTokenizer(object):
|
|||||||
to their model.
|
to their model.
|
||||||
**kwargs: passed to the `self.tokenize()` method
|
**kwargs: passed to the `self.tokenize()` method
|
||||||
"""
|
"""
|
||||||
|
if is_tf_available():
|
||||||
|
is_tf_tensor = False
|
||||||
|
if isinstance(text, tf.Tensor):
|
||||||
|
text = text.numpy()
|
||||||
|
is_tf_tensor = True
|
||||||
|
if isinstance(text, bytes):
|
||||||
|
text = text.decode('utf-8')
|
||||||
|
|
||||||
if text_pair is None:
|
if text_pair is None:
|
||||||
if add_special_tokens:
|
if add_special_tokens:
|
||||||
return self.add_special_tokens_single_sentence(self.convert_tokens_to_ids(self.tokenize(text, **kwargs)))
|
output = self.add_special_tokens_single_sentence(self.convert_tokens_to_ids(self.tokenize(text, **kwargs)))
|
||||||
else:
|
else:
|
||||||
return self.convert_tokens_to_ids(self.tokenize(text, **kwargs))
|
output = self.convert_tokens_to_ids(self.tokenize(text, **kwargs))
|
||||||
|
|
||||||
first_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text, **kwargs)]
|
|
||||||
second_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text_pair, **kwargs)]
|
|
||||||
|
|
||||||
if add_special_tokens:
|
|
||||||
return self.add_special_tokens_sentences_pair(first_sentence_tokens, second_sentence_tokens)
|
|
||||||
else:
|
else:
|
||||||
return first_sentence_tokens, second_sentence_tokens
|
first_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text, **kwargs)]
|
||||||
|
second_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text_pair, **kwargs)]
|
||||||
|
|
||||||
|
if add_special_tokens:
|
||||||
|
output = self.add_special_tokens_sentences_pair(first_sentence_tokens, second_sentence_tokens)
|
||||||
|
else:
|
||||||
|
output = first_sentence_tokens, second_sentence_tokens
|
||||||
|
|
||||||
|
if is_tf_available() and is_tf_tensor:
|
||||||
|
output = tf.constant(output)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
def add_special_tokens_single_sentence(self, token_ids):
|
def add_special_tokens_single_sentence(self, token_ids):
|
||||||
logger.warning("This tokenizer does not make use of special tokens. The sequence has been returned with no modification.")
|
logger.warning("This tokenizer does not make use of special tokens. The sequence has been returned with no modification.")
|
||||||
|
|||||||
Reference in New Issue
Block a user