From 0ea82b246f4a587295939b3621ce78b3d8e2ee60 Mon Sep 17 00:00:00 2001 From: LysandreJik Date: Tue, 24 Sep 2019 07:10:09 -0400 Subject: [PATCH] Updated tests --- .../tests/tokenization_tests_commons.py | 11 ++++++++--- pytorch_transformers/tokenization_utils.py | 2 +- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/pytorch_transformers/tests/tokenization_tests_commons.py b/pytorch_transformers/tests/tokenization_tests_commons.py index 1f84d36e7d..323e558310 100644 --- a/pytorch_transformers/tests/tokenization_tests_commons.py +++ b/pytorch_transformers/tests/tokenization_tests_commons.py @@ -264,9 +264,14 @@ class CommonTestCases: assert len(truncated_sequence) == len(sequence) - 2 assert truncated_sequence == truncated_second_sequence - def test_tokens_sent_to_encode(self): + def test_encode_input_type(self): tokenizer = self.get_tokenizer() sequence = "Let's encode this sequence" - tokens = tokenizer.encode(sequence) - tokenizer.encode(tokens, add_special_tokens=True) + + tokens = tokenizer.tokenize(sequence) + input_ids = tokenizer.convert_tokens_to_ids(tokens) + formatted_input = tokenizer.encode(sequence, add_special_tokens=True) + + assert tokenizer.encode(tokens, add_special_tokens=True) == formatted_input + assert tokenizer.encode(input_ids, add_special_tokens=True) == formatted_input diff --git a/pytorch_transformers/tokenization_utils.py b/pytorch_transformers/tokenization_utils.py index 1209c60de5..478ba6da87 100644 --- a/pytorch_transformers/tokenization_utils.py +++ b/pytorch_transformers/tokenization_utils.py @@ -744,7 +744,7 @@ class PreTrainedTokenizer(object): def get_input_ids(text): if isinstance(text, six.string_types): input_ids = self.convert_tokens_to_ids(self.tokenize(text, **kwargs)) - elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], str): + elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], six.string_types): input_ids = self.convert_tokens_to_ids(text) elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], int): input_ids = text