Fixing by correctly raising UnicodeDecodeError. (#13449)

This commit is contained in:
Nicolas Patry
2021-09-07 16:45:45 +02:00
committed by GitHub
parent 79815090ea
commit 5c7789d416
2 changed files with 41 additions and 8 deletions

View File

@@ -237,14 +237,7 @@ class ByT5Tokenizer(PreTrainedTokenizer):
else: else:
tok_string = bytes([ord(token)]) tok_string = bytes([ord(token)])
bstring += tok_string bstring += tok_string
# XXX: This is most likely incorrect, we want utf-8 errors string = bstring.decode("utf-8")
# to be triggered. However transformers test suite will
# try to decode every ID within the tokenizer on their own
# meaning it will attempt to try and decode invalid utf-8.
# Ignoring errors means passing tests, meanwhile correctly
# raising the errors means editing the automated tests to
# support that behavior (decoding an arbitrary ID might be invalid).
string = bstring.decode("utf-8", errors="ignore")
return string return string
# ByT5Tokenizer has no vocab file # ByT5Tokenizer has no vocab file

View File

@@ -15,9 +15,11 @@
import json import json
import os import os
import re
import shutil import shutil
import tempfile import tempfile
import unittest import unittest
from typing import Tuple
from transformers import AddedToken, BatchEncoding, ByT5Tokenizer from transformers import AddedToken, BatchEncoding, ByT5Tokenizer
from transformers.file_utils import cached_property, is_tf_available, is_torch_available from transformers.file_utils import cached_property, is_tf_available, is_torch_available
@@ -50,6 +52,44 @@ class ByT5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
def get_tokenizer(self, **kwargs) -> ByT5Tokenizer: def get_tokenizer(self, **kwargs) -> ByT5Tokenizer:
return self.tokenizer_class.from_pretrained(self.tmpdirname, **kwargs) return self.tokenizer_class.from_pretrained(self.tmpdirname, **kwargs)
def get_clean_sequence(self, tokenizer, with_prefix_space=False, max_length=20, min_length=5) -> Tuple[str, list]:
# XXX The default common tokenizer tests assume that every ID is decodable on its own.
# This assumption is invalid for ByT5 because single bytes might not be
# valid utf-8 (byte 128 for instance).
# Here we're overriding the smallest possible method to provide
# a clean sequence without making the same assumption.
toks = []
for i in range(len(tokenizer)):
try:
tok = tokenizer.decode([i], clean_up_tokenization_spaces=False)
except UnicodeDecodeError:
pass
toks.append((i, tok))
toks = list(filter(lambda t: re.match(r"^[ a-zA-Z]+$", t[1]), toks))
toks = list(filter(lambda t: [t[0]] == tokenizer.encode(t[1], add_special_tokens=False), toks))
if max_length is not None and len(toks) > max_length:
toks = toks[:max_length]
if min_length is not None and len(toks) < min_length and len(toks) > 0:
while len(toks) < min_length:
toks = toks + toks
# toks_str = [t[1] for t in toks]
toks_ids = [t[0] for t in toks]
# Ensure consistency
output_txt = tokenizer.decode(toks_ids, clean_up_tokenization_spaces=False)
if " " not in output_txt and len(toks_ids) > 1:
output_txt = (
tokenizer.decode([toks_ids[0]], clean_up_tokenization_spaces=False)
+ " "
+ tokenizer.decode(toks_ids[1:], clean_up_tokenization_spaces=False)
)
if with_prefix_space:
output_txt = " " + output_txt
output_ids = tokenizer.encode(output_txt, add_special_tokens=False)
return output_txt, output_ids
def test_eos_treatment(self): def test_eos_treatment(self):
tokenizer = self.t5_base_tokenizer tokenizer = self.t5_base_tokenizer
batch_with_eos_added = tokenizer(["hi</s>", "I went to the gym</s>", "</s>"]) batch_with_eos_added = tokenizer(["hi</s>", "I went to the gym</s>", "</s>"])