From c4acc3a8e96d7cb1d69b72899c4e730719cfe498 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Wed, 25 Sep 2019 10:19:14 +0200 Subject: [PATCH] let encode accept tensor inputs --- pytorch_transformers/__init__.py | 9 ++---- pytorch_transformers/file_utils.py | 20 ++++++++++++ pytorch_transformers/tokenization_utils.py | 36 ++++++++++++++++------ 3 files changed, 48 insertions(+), 17 deletions(-) diff --git a/pytorch_transformers/__init__.py b/pytorch_transformers/__init__.py index 508d0f84c4..5efbece795 100644 --- a/pytorch_transformers/__init__.py +++ b/pytorch_transformers/__init__.py @@ -163,10 +163,5 @@ if _tf_available and _torch_available: # Files and general utilities from .file_utils import (PYTORCH_TRANSFORMERS_CACHE, PYTORCH_PRETRAINED_BERT_CACHE, cached_path, add_start_docstrings, add_end_docstrings, - WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME, CONFIG_NAME) - -def is_torch_available(): - return _torch_available - -def is_tf_available(): - return _tf_available + WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME, CONFIG_NAME, + is_tf_available, is_torch_available) \ No newline at end of file diff --git a/pytorch_transformers/file_utils.py b/pytorch_transformers/file_utils.py index 34333aaafb..2c761ef51d 100644 --- a/pytorch_transformers/file_utils.py +++ b/pytorch_transformers/file_utils.py @@ -23,6 +23,20 @@ from botocore.exceptions import ClientError import requests from tqdm import tqdm +try: + import tensorflow as tf + assert int(tf.__version__[0]) >= 2 + _tf_available = True # pylint: disable=invalid-name +except (ImportError, AssertionError): + _tf_available = False # pylint: disable=invalid-name + +try: + import torch + _torch_available = True # pylint: disable=invalid-name +except ImportError: + _torch_available = False # pylint: disable=invalid-name + + try: from torch.hub import _get_torch_home torch_cache_home = _get_torch_home() @@ -55,6 +69,12 @@ CONFIG_NAME = "config.json" logger = logging.getLogger(__name__) # pylint: disable=invalid-name +def is_torch_available(): + return _torch_available + +def is_tf_available(): + return _tf_available + if not six.PY2: def add_start_docstrings(*docstr): def docstring_decorator(fn): diff --git a/pytorch_transformers/tokenization_utils.py b/pytorch_transformers/tokenization_utils.py index 53b8d245b8..5a307c5979 100644 --- a/pytorch_transformers/tokenization_utils.py +++ b/pytorch_transformers/tokenization_utils.py @@ -23,7 +23,10 @@ import six import copy from io import open -from .file_utils import cached_path +from .file_utils import cached_path, is_tf_available + +if is_tf_available(): + import tensorflow as tf logger = logging.getLogger(__name__) @@ -686,19 +689,32 @@ class PreTrainedTokenizer(object): to their model. **kwargs: passed to the `self.tokenize()` method """ + if is_tf_available(): + is_tf_tensor = False + if isinstance(text, tf.Tensor): + text = text.numpy() + is_tf_tensor = True + if isinstance(text, bytes): + text = text.decode('utf-8') + if text_pair is None: if add_special_tokens: - return self.add_special_tokens_single_sentence(self.convert_tokens_to_ids(self.tokenize(text, **kwargs))) + output = self.add_special_tokens_single_sentence(self.convert_tokens_to_ids(self.tokenize(text, **kwargs))) else: - return self.convert_tokens_to_ids(self.tokenize(text, **kwargs)) - - first_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text, **kwargs)] - second_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text_pair, **kwargs)] - - if add_special_tokens: - return self.add_special_tokens_sentences_pair(first_sentence_tokens, second_sentence_tokens) + output = self.convert_tokens_to_ids(self.tokenize(text, **kwargs)) else: - return first_sentence_tokens, second_sentence_tokens + first_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text, **kwargs)] + second_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text_pair, **kwargs)] + + if add_special_tokens: + output = self.add_special_tokens_sentences_pair(first_sentence_tokens, second_sentence_tokens) + else: + output = first_sentence_tokens, second_sentence_tokens + + if is_tf_available() and is_tf_tensor: + output = tf.constant(output) + + return output def add_special_tokens_single_sentence(self, token_ids): logger.warning("This tokenizer does not make use of special tokens. The sequence has been returned with no modification.")