From 3d87991f606b36dc54318ac3dee9803001ef161d Mon Sep 17 00:00:00 2001 From: LysandreJik Date: Tue, 13 Aug 2019 12:00:24 -0400 Subject: [PATCH] Fixed error with encoding --- .../tests/tokenization_roberta_test.py | 7 +++++-- pytorch_transformers/tokenization_utils.py | 11 +++-------- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/pytorch_transformers/tests/tokenization_roberta_test.py b/pytorch_transformers/tests/tokenization_roberta_test.py index b76b3e311d..a8f940ae43 100644 --- a/pytorch_transformers/tests/tokenization_roberta_test.py +++ b/pytorch_transformers/tests/tokenization_roberta_test.py @@ -81,11 +81,14 @@ class RobertaTokenizationTest(CommonTestCases.CommonTokenizerTester): text = tokenizer.encode("sequence builders") text_2 = tokenizer.encode("multi-sequence build") + encoded_text_from_decode = tokenizer.encode("sequence builders", add_special_tokens=True) + encoded_pair_from_decode = tokenizer.encode("sequence builders", "multi-sequence build", add_special_tokens=True) + encoded_sentence = tokenizer.add_special_tokens_single_sentence(text) encoded_pair = tokenizer.add_special_tokens_sentences_pair(text, text_2) - assert encoded_sentence == [0] + text + [2] - assert encoded_pair == [0] + text + [2, 2] + text_2 + [2] + assert encoded_sentence == encoded_text_from_decode + assert encoded_pair == encoded_pair_from_decode if __name__ == '__main__': diff --git a/pytorch_transformers/tokenization_utils.py b/pytorch_transformers/tokenization_utils.py index 3253596058..7bb9fd9d29 100644 --- a/pytorch_transformers/tokenization_utils.py +++ b/pytorch_transformers/tokenization_utils.py @@ -519,24 +519,19 @@ class PreTrainedTokenizer(object): def _convert_token_to_id(self, token): raise NotImplementedError - def encode(self, text, add_special_tokens=False, *sequences): + def encode(self, text, text_pair=None, add_special_tokens=False): """ Converts a string in a sequence of ids (integer), using the tokenizer and vocabulary. Same doing ``self.convert_tokens_to_ids(self.tokenize(text))``. """ - - if len(sequences) == 0: + if text_pair is None: if add_special_tokens: return self.add_special_tokens_single_sentence(self.convert_tokens_to_ids(self.tokenize(text))) else: return self.convert_tokens_to_ids(self.tokenize(text)) - if len(sequences) > 1: - logger.warning("Tokenization currently only supports sentence pairs. Ignoring every string following the " - "initial two.") - first_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text)] - second_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(sequences[0])] + second_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text_pair)] if add_special_tokens: return self.add_special_tokens_sentences_pair(first_sentence_tokens, second_sentence_tokens)