From fede4ef45ddec6e7706548836b3ae2a7728fa93a Mon Sep 17 00:00:00 2001 From: thomwolf Date: Mon, 2 Sep 2019 02:27:39 +0200 Subject: [PATCH] fixing #1133 --- .../tests/tokenization_tests_commons.py | 7 ++++-- pytorch_transformers/tokenization_utils.py | 22 +++++++++++++++++-- 2 files changed, 25 insertions(+), 4 deletions(-) diff --git a/pytorch_transformers/tests/tokenization_tests_commons.py b/pytorch_transformers/tests/tokenization_tests_commons.py index cdc6cddf00..65f45c496c 100644 --- a/pytorch_transformers/tests/tokenization_tests_commons.py +++ b/pytorch_transformers/tests/tokenization_tests_commons.py @@ -104,7 +104,7 @@ class CommonTestCases: self.assertNotEqual(vocab_size, 0) self.assertEqual(vocab_size, all_size) - new_toks = ["aaaaabbbbbb", "cccccccccdddddddd"] + new_toks = ["aaaaa bbbbbb", "cccccccccdddddddd"] added_toks = tokenizer.add_tokens(new_toks) vocab_size_2 = tokenizer.vocab_size all_size_2 = len(tokenizer) @@ -114,7 +114,9 @@ class CommonTestCases: self.assertEqual(added_toks, len(new_toks)) self.assertEqual(all_size_2, all_size + len(new_toks)) - tokens = tokenizer.encode("aaaaabbbbbb low cccccccccdddddddd l") + tokens = tokenizer.encode("aaaaa bbbbbb low cccccccccdddddddd l") + out_string = tokenizer.decode(tokens) + self.assertGreaterEqual(len(tokens), 4) self.assertGreater(tokens[0], tokenizer.vocab_size - 1) self.assertGreater(tokens[-2], tokenizer.vocab_size - 1) @@ -131,6 +133,7 @@ class CommonTestCases: self.assertEqual(all_size_3, all_size_2 + len(new_toks_2)) tokens = tokenizer.encode(">>>>|||<||<<|<< aaaaabbbbbb low cccccccccdddddddd <<<<<|||>|>>>>|> l") + out_string = tokenizer.decode(tokens) self.assertGreaterEqual(len(tokens), 6) self.assertGreater(tokens[0], tokenizer.vocab_size - 1) diff --git a/pytorch_transformers/tokenization_utils.py b/pytorch_transformers/tokenization_utils.py index 3aed47eb09..4b52409eea 100644 --- a/pytorch_transformers/tokenization_utils.py +++ b/pytorch_transformers/tokenization_utils.py @@ -722,7 +722,7 @@ class PreTrainedTokenizer(object): return self._convert_id_to_token(ids) tokens = [] for index in ids: - if index in self.all_special_ids and skip_special_tokens: + if skip_special_tokens and index in self.all_special_ids: continue if index in self.added_tokens_decoder: tokens.append(self.added_tokens_decoder[index]) @@ -747,7 +747,25 @@ class PreTrainedTokenizer(object): Similar to doing ``self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))``. """ filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens) - text = self.convert_tokens_to_string(filtered_tokens) + + # To avoid mixing byte-level and unicode for byte-level BPT + # we need to build string separatly for added tokens and byte-level tokens + # cf. https://github.com/huggingface/pytorch-transformers/issues/1133 + sub_texts = [] + current_sub_text = [] + for token in filtered_tokens: + if skip_special_tokens and token in self.all_special_ids: + continue + if token in self.added_tokens_encoder: + if current_sub_text: + sub_texts.append(self.convert_tokens_to_string(current_sub_text)) + current_sub_text = [] + sub_texts.append(token) + else: + current_sub_text.append(token) + if current_sub_text: + sub_texts.append(self.convert_tokens_to_string(current_sub_text)) + text = ''.join(sub_texts) if self._sep_token is not None and self._sep_token in text: text = text.replace(self._cls_token, self._sep_token)