From 5c7789d4167064f7464b8801c7488a9a2878480a Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Tue, 7 Sep 2021 16:45:45 +0200
Subject: [PATCH] Fixing by correctly raising UnicodeDecodeError. (#13449)

---
 .../models/byt5/tokenization_byt5.py |  9 +----
 tests/test_tokenization_byt5.py      | 40 +++++++++++++++++++
 2 files changed, 41 insertions(+), 8 deletions(-)

diff --git a/src/transformers/models/byt5/tokenization_byt5.py b/src/transformers/models/byt5/tokenization_byt5.py
index e5e3ecf6cf..bda3313f8c 100644
--- a/src/transformers/models/byt5/tokenization_byt5.py
+++ b/src/transformers/models/byt5/tokenization_byt5.py
@@ -237,14 +237,7 @@ class ByT5Tokenizer(PreTrainedTokenizer):
             else:
                 tok_string = bytes([ord(token)])
             bstring += tok_string
-        # XXX: This is most likely incorrect, we want utf-8 errors
-        # to be triggered. However transformers test suite will
-        # try to decode every ID within the tokenizer on their own
-        # meaning it will attempt to try and decode invalid utf-8.
-        # Ignoring errors means passing tests, meanwhile correctly
-        # raising the errors means editing the automated tests to
-        # support that behavior (decoding an arbitrary ID might be invalid).
-        string = bstring.decode("utf-8", errors="ignore")
+        string = bstring.decode("utf-8")
         return string
 
     # ByT5Tokenizer has no vocab file
diff --git a/tests/test_tokenization_byt5.py b/tests/test_tokenization_byt5.py
index 46754047bf..003e6bd51f 100644
--- a/tests/test_tokenization_byt5.py
+++ b/tests/test_tokenization_byt5.py
@@ -15,9 +15,11 @@
 
 import json
 import os
+import re
 import shutil
 import tempfile
 import unittest
+from typing import Tuple
 
 from transformers import AddedToken, BatchEncoding, ByT5Tokenizer
 from transformers.file_utils import cached_property, is_tf_available, is_torch_available
@@ -50,6 +52,44 @@ class ByT5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
     def get_tokenizer(self, **kwargs) -> ByT5Tokenizer:
         return self.tokenizer_class.from_pretrained(self.tmpdirname, **kwargs)
 
+    def get_clean_sequence(self, tokenizer, with_prefix_space=False, max_length=20, min_length=5) -> Tuple[str, list]:
+        # XXX The default common tokenizer tests assume that every ID is decodable on its own.
+        # This assumption is invalid for ByT5 because single bytes might not be
+        # valid utf-8 (byte 128 for instance).
+        # Here we're overriding the smallest possible method to provide
+        # a clean sequence without making the same assumption.
+
+        toks = []
+        for i in range(len(tokenizer)):
+            try:
+                tok = tokenizer.decode([i], clean_up_tokenization_spaces=False)
+            except UnicodeDecodeError:
+                continue  # not valid utf-8 on its own: skip it rather than append a stale/unbound `tok`
+            toks.append((i, tok))
+
+        toks = list(filter(lambda t: re.match(r"^[ a-zA-Z]+$", t[1]), toks))
+        toks = list(filter(lambda t: [t[0]] == tokenizer.encode(t[1], add_special_tokens=False), toks))
+        if max_length is not None and len(toks) > max_length:
+            toks = toks[:max_length]
+        if min_length is not None and len(toks) < min_length and len(toks) > 0:
+            while len(toks) < min_length:
+                toks = toks + toks
+        # toks_str = [t[1] for t in toks]
+        toks_ids = [t[0] for t in toks]
+
+        # Ensure consistency
+        output_txt = tokenizer.decode(toks_ids, clean_up_tokenization_spaces=False)
+        if " " not in output_txt and len(toks_ids) > 1:
+            output_txt = (
+                tokenizer.decode([toks_ids[0]], clean_up_tokenization_spaces=False)
+                + " "
+                + tokenizer.decode(toks_ids[1:], clean_up_tokenization_spaces=False)
+            )
+        if with_prefix_space:
+            output_txt = " " + output_txt
+        output_ids = tokenizer.encode(output_txt, add_special_tokens=False)
+        return output_txt, output_ids
+
     def test_eos_treatment(self):
         tokenizer = self.t5_base_tokenizer
         batch_with_eos_added = tokenizer(["hi", "I went to the gym", ""])