let encode accept tensor inputs
This commit is contained in:
@@ -23,7 +23,10 @@ import six
|
||||
import copy
|
||||
from io import open
|
||||
|
||||
from .file_utils import cached_path
|
||||
from .file_utils import cached_path, is_tf_available
|
||||
|
||||
if is_tf_available():
|
||||
import tensorflow as tf
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -686,19 +689,32 @@ class PreTrainedTokenizer(object):
|
||||
to their model.
|
||||
**kwargs: passed to the `self.tokenize()` method
|
||||
"""
|
||||
if is_tf_available():
|
||||
is_tf_tensor = False
|
||||
if isinstance(text, tf.Tensor):
|
||||
text = text.numpy()
|
||||
is_tf_tensor = True
|
||||
if isinstance(text, bytes):
|
||||
text = text.decode('utf-8')
|
||||
|
||||
if text_pair is None:
|
||||
if add_special_tokens:
|
||||
return self.add_special_tokens_single_sentence(self.convert_tokens_to_ids(self.tokenize(text, **kwargs)))
|
||||
output = self.add_special_tokens_single_sentence(self.convert_tokens_to_ids(self.tokenize(text, **kwargs)))
|
||||
else:
|
||||
return self.convert_tokens_to_ids(self.tokenize(text, **kwargs))
|
||||
|
||||
first_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text, **kwargs)]
|
||||
second_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text_pair, **kwargs)]
|
||||
|
||||
if add_special_tokens:
|
||||
return self.add_special_tokens_sentences_pair(first_sentence_tokens, second_sentence_tokens)
|
||||
output = self.convert_tokens_to_ids(self.tokenize(text, **kwargs))
|
||||
else:
|
||||
return first_sentence_tokens, second_sentence_tokens
|
||||
first_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text, **kwargs)]
|
||||
second_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text_pair, **kwargs)]
|
||||
|
||||
if add_special_tokens:
|
||||
output = self.add_special_tokens_sentences_pair(first_sentence_tokens, second_sentence_tokens)
|
||||
else:
|
||||
output = first_sentence_tokens, second_sentence_tokens
|
||||
|
||||
if is_tf_available() and is_tf_tensor:
|
||||
output = tf.constant(output)
|
||||
|
||||
return output
|
||||
|
||||
def add_special_tokens_single_sentence(self, token_ids):
|
||||
logger.warning("This tokenizer does not make use of special tokens. The sequence has been returned with no modification.")
|
||||
|
||||
Reference in New Issue
Block a user