Fixing by correctly raising UnicodeDecodeError. (#13449)
This commit is contained in:
@@ -237,14 +237,7 @@ class ByT5Tokenizer(PreTrainedTokenizer):
|
||||
else:
|
||||
tok_string = bytes([ord(token)])
|
||||
bstring += tok_string
|
||||
# XXX: This is most likely incorrect, we want utf-8 errors
|
||||
# to be triggered. However transformers test suite will
|
||||
# try to decode every ID within the tokenizer on their own
|
||||
# meaning it will attempt to try and decode invalid utf-8.
|
||||
# Ignoring errors means passing tests, meanwhile correctly
|
||||
# raising the errors means editing the automated tests to
|
||||
# support that behavior (decoding an arbitrary ID might be invalid).
|
||||
string = bstring.decode("utf-8", errors="ignore")
|
||||
string = bstring.decode("utf-8")
|
||||
return string
|
||||
|
||||
# ByT5Tokenizer has no vocab file
|
||||
|
||||
@@ -15,9 +15,11 @@
|
||||
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
from typing import Tuple
|
||||
|
||||
from transformers import AddedToken, BatchEncoding, ByT5Tokenizer
|
||||
from transformers.file_utils import cached_property, is_tf_available, is_torch_available
|
||||
@@ -50,6 +52,44 @@ class ByT5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
def get_tokenizer(self, **kwargs) -> ByT5Tokenizer:
|
||||
return self.tokenizer_class.from_pretrained(self.tmpdirname, **kwargs)
|
||||
|
||||
def get_clean_sequence(self, tokenizer, with_prefix_space=False, max_length=20, min_length=5) -> Tuple[str, list]:
|
||||
# XXX The default common tokenizer tests assume that every ID is decodable on its own.
|
||||
# This assumption is invalid for ByT5 because single bytes might not be
|
||||
# valid utf-8 (byte 128 for instance).
|
||||
# Here we're overriding the smallest possible method to provide
|
||||
# a clean sequence without making the same assumption.
|
||||
|
||||
toks = []
|
||||
for i in range(len(tokenizer)):
|
||||
try:
|
||||
tok = tokenizer.decode([i], clean_up_tokenization_spaces=False)
|
||||
except UnicodeDecodeError:
|
||||
pass
|
||||
toks.append((i, tok))
|
||||
|
||||
toks = list(filter(lambda t: re.match(r"^[ a-zA-Z]+$", t[1]), toks))
|
||||
toks = list(filter(lambda t: [t[0]] == tokenizer.encode(t[1], add_special_tokens=False), toks))
|
||||
if max_length is not None and len(toks) > max_length:
|
||||
toks = toks[:max_length]
|
||||
if min_length is not None and len(toks) < min_length and len(toks) > 0:
|
||||
while len(toks) < min_length:
|
||||
toks = toks + toks
|
||||
# toks_str = [t[1] for t in toks]
|
||||
toks_ids = [t[0] for t in toks]
|
||||
|
||||
# Ensure consistency
|
||||
output_txt = tokenizer.decode(toks_ids, clean_up_tokenization_spaces=False)
|
||||
if " " not in output_txt and len(toks_ids) > 1:
|
||||
output_txt = (
|
||||
tokenizer.decode([toks_ids[0]], clean_up_tokenization_spaces=False)
|
||||
+ " "
|
||||
+ tokenizer.decode(toks_ids[1:], clean_up_tokenization_spaces=False)
|
||||
)
|
||||
if with_prefix_space:
|
||||
output_txt = " " + output_txt
|
||||
output_ids = tokenizer.encode(output_txt, add_special_tokens=False)
|
||||
return output_txt, output_ids
|
||||
|
||||
def test_eos_treatment(self):
|
||||
tokenizer = self.t5_base_tokenizer
|
||||
batch_with_eos_added = tokenizer(["hi</s>", "I went to the gym</s>", "</s>"])
|
||||
|
||||
Reference in New Issue
Block a user