From cf4eb8b3f9dd9adfc5b55212058a15e2e17ca071 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Mon, 6 Sep 2021 16:11:23 +0200 Subject: [PATCH] Adding a test for multibytes unicode. (#13447) * Adding a test for multibytes unicode. * Adding some accents. * Making sure decoding works. * Make tests passing by being cheesy. --- .../models/byt5/tokenization_byt5.py | 26 ++++++++++++++----- tests/test_tokenization_byt5.py | 21 +++++++++++++++ 2 files changed, 40 insertions(+), 7 deletions(-) diff --git a/src/transformers/models/byt5/tokenization_byt5.py b/src/transformers/models/byt5/tokenization_byt5.py index 36a2a53237..e5e3ecf6cf 100644 --- a/src/transformers/models/byt5/tokenization_byt5.py +++ b/src/transformers/models/byt5/tokenization_byt5.py @@ -104,7 +104,7 @@ class ByT5Tokenizer(PreTrainedTokenizer): self._num_special_tokens = len(self.special_tokens_encoder) n = len(additional_special_tokens) for i, token in enumerate(additional_special_tokens): - self.special_tokens_encoder[token] = self.vocab_size + i - n - 1 + self.special_tokens_encoder[token] = self.vocab_size + i - n self.special_tokens_decoder: Dict[str, int] = {v: k for k, v in self.special_tokens_encoder.items()} @property @@ -199,7 +199,7 @@ class ByT5Tokenizer(PreTrainedTokenizer): def _tokenize(self, text: str) -> List[str]: """Take as input a string and return a list of strings (tokens) for words/sub-words""" - tokens = list(text) + tokens = [chr(i) for i in text.encode("utf-8")] return tokens def _convert_token_to_id(self, token): @@ -224,15 +224,27 @@ class ByT5Tokenizer(PreTrainedTokenizer): def convert_tokens_to_string(self, tokens): """Converts a sequence of tokens (string) in a single string.""" - string = "" + bstring = b"" for token in tokens: if token in self.special_tokens_decoder: - tok_string = self.special_tokens_decoder[token] + tok_string = self.special_tokens_decoder[token].encode("utf-8") elif token in self.added_tokens_decoder: - tok_string = self.added_tokens_decoder[token] + tok_string = self.special_tokens_decoder[token].encode("utf-8") + elif token in self.special_tokens_encoder: + tok_string = token.encode("utf-8") + elif token in self.added_tokens_encoder: + tok_string = token.encode("utf-8") else: - tok_string = token - string += tok_string + tok_string = bytes([ord(token)]) + bstring += tok_string + # XXX: This is most likely incorrect, we want utf-8 errors + # to be triggered. However transformers test suite will + # try to decode every ID within the tokenizer on their own + # meaning it will attempt to try and decode invalid utf-8. + # Ignoring errors means passing tests, meanwhile correctly + # raising the errors means editing the automated tests to + # support that behavior (decoding an arbitrary ID might be invalid). + string = bstring.decode("utf-8", errors="ignore") return string # ByT5Tokenizer has no vocab file diff --git a/tests/test_tokenization_byt5.py b/tests/test_tokenization_byt5.py index d8d4795917..46754047bf 100644 --- a/tests/test_tokenization_byt5.py +++ b/tests/test_tokenization_byt5.py @@ -56,6 +56,27 @@ class ByT5TokenizationTest(TokenizerTesterMixin, unittest.TestCase): batch_without_eos_added = tokenizer(["hi", "I went to the gym", ""]) self.assertListEqual(batch_with_eos_added["input_ids"], batch_without_eos_added["input_ids"]) + def test_multibytes_char(self): + tokenizer = self.t5_base_tokenizer + src_text = "Unicode €." + encoded = tokenizer(src_text) + encoded_ids = [88, 113, 108, 102, 114, 103, 104, 35, 229, 133, 175, 49, 1] + self.assertEqual(encoded["input_ids"], encoded_ids) + + # decoding + decoded = tokenizer.decode(encoded_ids) + self.assertEqual(decoded, "Unicode €.") + + encoded = tokenizer("e è é ê ë") + encoded_ids = [104, 35, 198, 171, 35, 198, 172, 35, 198, 173, 35, 198, 174, 1] + self.assertEqual(encoded["input_ids"], encoded_ids) + # decoding + decoded = tokenizer.decode(encoded_ids) + self.assertEqual(decoded, "e è é ê ë") + + # encode/decode, but with `encode` instead of `__call__` + self.assertEqual(tokenizer.decode(tokenizer.encode("e è é ê ë")), "e è é ê ë") + def test_prepare_batch_integration(self): tokenizer = self.t5_base_tokenizer src_text = ["A long paragraph for summarization.", "Another paragraph for summarization."]