From 78863f6b36e975f718eeae01a6cea3681ff735aa Mon Sep 17 00:00:00 2001 From: thomwolf Date: Wed, 25 Sep 2019 21:09:46 +0200 Subject: [PATCH] fix tokenizer to tensors --- pytorch_transformers/tokenization_utils.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/pytorch_transformers/tokenization_utils.py b/pytorch_transformers/tokenization_utils.py index f7c0e976ab..74797ea206 100644 --- a/pytorch_transformers/tokenization_utils.py +++ b/pytorch_transformers/tokenization_utils.py @@ -27,6 +27,8 @@ from .file_utils import cached_path, is_tf_available, is_torch_available if is_tf_available(): import tensorflow as tf +if is_torch_available(): + import torch logger = logging.getLogger(__name__) @@ -849,7 +851,11 @@ class PreTrainedTokenizer(object): if return_tensors == 'tf' and is_tf_available(): sequence = tf.constant(sequence) token_type_ids = tf.constant(token_type_ids) - elif return_tensors = 'pt' and is + elif return_tensors == 'pt' and is_torch_available(): + sequence = torch.tensor(sequence) + token_type_ids = torch.tensor(token_type_ids) + elif return_tensors is not None: + logger.warning("Unable to convert output to tensors format {}, PyTorch or TensorFlow is not available.".format(return_tensors)) encoded_inputs["input_ids"] = sequence encoded_inputs["token_type_ids"] = token_type_ids