diff --git a/pytorch_transformers/tests/tokenization_tests_commons.py b/pytorch_transformers/tests/tokenization_tests_commons.py index 5da19bb660..8a3b56a058 100644 --- a/pytorch_transformers/tests/tokenization_tests_commons.py +++ b/pytorch_transformers/tests/tokenization_tests_commons.py @@ -263,3 +263,10 @@ class CommonTestCases: assert overflowing_tokens_first_truncated == sequence_0_no_special_tokens[-(2 + stride):] assert len(truncated_sequence) == len(sequence) - 2 assert truncated_sequence == truncated_second_sequence + + def test_tokens_sent_to_encode(self): + tokenizer = self.get_tokenizer() + + sequence = "Let's encode this sequence" + tokens = tokenizer.encode(sequence) + tokenizer.encode(tokens, add_special_tokens=True) diff --git a/pytorch_transformers/tokenization_utils.py b/pytorch_transformers/tokenization_utils.py index 0b23d00b8a..b32d6daeb6 100644 --- a/pytorch_transformers/tokenization_utils.py +++ b/pytorch_transformers/tokenization_utils.py @@ -707,14 +707,14 @@ class PreTrainedTokenizer(object): """ if text_pair is None: if add_special_tokens: - sequence_tokens = self.convert_tokens_to_ids(self.tokenize(text, **kwargs)) + sequence_tokens = self.convert_tokens_to_ids(self.tokenize(text, **kwargs)) if isinstance(text, str) else text return self.add_special_tokens_single_sequence(sequence_tokens) else: - ids = self.convert_tokens_to_ids(self.tokenize(text, **kwargs)) + ids = self.convert_tokens_to_ids(self.tokenize(text, **kwargs)) if isinstance(text, str) else text return ids - first_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text, **kwargs)] - second_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text_pair, **kwargs)] + first_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text, **kwargs)] if isinstance(text, str) else text + second_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text_pair, **kwargs)] if isinstance(text_pair, str) else text_pair if add_special_tokens: return self.add_special_tokens_sequence_pair(first_sentence_tokens, second_sentence_tokens) @@ -754,7 +754,7 @@ class PreTrainedTokenizer(object): information = {} if text_pair is None: - sequence_tokens = self.convert_tokens_to_ids(self.tokenize(text, **kwargs)) + sequence_tokens = self.convert_tokens_to_ids(self.tokenize(text, **kwargs)) if isinstance(text, str) else text if add_special_tokens: information = self.prepare_for_model(sequence_tokens, max_length, stride) else: @@ -766,8 +766,8 @@ class PreTrainedTokenizer(object): if output_mask: information["mask"] = [0] * len(information["sequence"]) else: - first_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text, **kwargs)] - second_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text_pair, **kwargs)] + first_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text, **kwargs)] if isinstance(text, str) else text + second_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text_pair, **kwargs)] if isinstance(text_pair, str) else text_pair if add_special_tokens: information = self.prepare_pair_for_model(