[Patch-t5-tokenizer] Patches the changes on T5 to make sure previous behaviour is still valide for beginning of words (#24622)

* patch `_tokenize` function

* more tests

* properly fix

* fixup

* Update src/transformers/models/t5/tokenization_t5.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* fix without ifs

* update

* protect import

* add python processing

* is first needed

* add doc and update with lefacy

* updaate

* fix T5 SPM converter

* styling

* fix T5 warning

* add is_seqio_available

* remove is_first

* revert some changes

* more tests and update

* update llama test batterie

* fixup

* refactor T5 spm common tests

* draft the llama tests

* update

* uopdate test

* nits

* refine

* name nit

* fix t5 tests

* fix T5

* update

* revert convert slow to fast changes that fail lots of tests

* legacy support

* fixup

* nits is first not defined

* don't use legacy behaviour for switch transformers

* style

* My attempt to check.

* nits

* fixes

* update

* fixup

* Apply suggestions from code review

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* updates

* fixup

* add legacy warning

* fixup

* warning_once nit

* update t5 documentation test

* update llama tok documentation

* add space to warning

* nits

* nit

* Apply suggestions from code review

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* last nits

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
This commit is contained in:
Arthur
2023-07-11 22:02:18 +09:00
committed by GitHub
parent b3ab3fac1d
commit b15343de6f
11 changed files with 387 additions and 52 deletions

View File

@@ -498,3 +498,89 @@ class LlamaIntegrationTest(unittest.TestCase):
decoded2 = rust_tokenizer.decode(encoded2, skip_special_tokens=True)
self.assertEqual(decoded1, decoded2)
@require_sentencepiece
@require_tokenizers
class CommonSpmIntegrationTests(unittest.TestCase):
"""
A class that regroups important test to make sure that we properly handle the special tokens.
"""
@classmethod
def setUpClass(cls):
tokenizer = LlamaTokenizer(SAMPLE_VOCAB, extra_ids=0, add_bos_token=False, legacy=False)
tokenizer.add_special_tokens({"additional_special_tokens": ["<s>"]})
tokenizer._create_trie(tokenizer.all_special_tokens)
# TODO ArthurZ the above is necessary as addedTokens / intialization sucks. Trie is not correctly created
# So the extra ids are split....
cls.tokenizer = tokenizer
return cls
def test_add_dummy_prefix(self):
# make sure `'▁'` is prepended, and outputs match sp_model's
# `sentencepiece.NormalizerSpec.add_dummy_prefix` attribute
input_ids = self.tokenizer.encode(". Hello")
self.assertEqual(input_ids, [7, 4, 156, 86, 20])
sp_encode = self.tokenizer.sp_model.encode(". Hello")
self.assertEqual(input_ids, sp_encode)
tokens = self.tokenizer.tokenize(". Hello")
self.assertEqual(tokens, ["", ".", "▁He", "ll", "o"])
def test_remove_extra_whitespaces(self):
# make sure the extra spaces are eaten. Since the sample vocab does not have
# `______`. sentencepiece.NormalizerSpec.remove_extra_whitespaces attribute is set to False
input_ids = self.tokenizer.encode(" . Hello")
self.assertEqual(input_ids, [7, 4, 156, 86, 20])
sp_encode = self.tokenizer.sp_model.encode(" . Hello")
self.assertEqual(input_ids, sp_encode)
tokens = self.tokenizer.tokenize(" . Hello")
self.assertEqual(tokens, ["", ".", "▁He", "ll", "o"])
# `'▁'` is also a whitespace
input_ids = self.tokenizer.encode("▁He is not")
self.assertEqual(input_ids, [156, 46, 44])
tokens = self.tokenizer.tokenize("▁He is not")
sp_encode = self.tokenizer.sp_model.encode("▁He is not")
self.assertEqual(input_ids, sp_encode)
self.assertEqual(tokens, ["▁He", "▁is", "▁not"]) # no extra space added
input_ids = self.tokenizer.encode("▁He is not<s> ▁He")
self.assertEqual(input_ids, [156, 46, 44, 1, 156])
tokens = self.tokenizer.tokenize("▁He is not<s> ▁He")
self.assertEqual(tokens, ["▁He", "▁is", "▁not", "<s>", "▁He"]) # spaces are eaten by spm + our strip
# make sure that the output after the extra id is the same as if
# extra_id was not there
input_ids = self.tokenizer.encode("▁He is not ▁He")
self.assertEqual(input_ids, [156, 46, 44, 156])
tokens = self.tokenizer.tokenize("▁He is not ▁He")
self.assertEqual(tokens, ["▁He", "▁is", "▁not", "▁He"]) # spaces are eaten by spm even if not start
def test_character_after_special_token(self):
# Make sure that `tokenizer.tokenize` is similar to
# adding the equivalent special token to the vocab
input_ids = self.tokenizer.encode("Hey <s>I")
self.assertEqual(input_ids, [156, 30, 1, 100])
sp_encode = self.tokenizer.sp_model.encode("Hey .I")
# the last token should be 100
self.assertEqual(input_ids[-1], sp_encode[-1])
tokens = self.tokenizer.tokenize("<s>I")
self.assertEqual(tokens, ["<s>", "I"])
input_ids = self.tokenizer.encode("Hello, <s>,")
self.assertEqual(input_ids, [156, 86, 20, 3, 1, 3])
tokens = self.tokenizer.tokenize("Hello, <s>,")
self.assertEqual(tokens, ["▁He", "ll", "o", ",", "<s>", ","])
def test_special_tokens_strip(self):
input_ids = self.tokenizer.encode(" <s> ,")
self.assertEqual(input_ids, [1, 7, 3])
tokens = self.tokenizer.tokenize(" <s> ,")
# spaces are eaten by rstrip / lstrip + spm sp_model.encode(" ") = []
self.assertEqual(tokens, ["<s>", "", ","])
input_ids = self.tokenizer.encode("No <s> ▁He")
self.assertEqual(input_ids, [284, 1, 156])
tokens = self.tokenizer.tokenize("No <s> ▁He")
self.assertEqual(tokens, ["▁No", "<s>", "▁He"]) # spaces are eaten by rstrip / lstrip

View File

@@ -1143,13 +1143,16 @@ class SwitchTransformerModelIntegrationTests(unittest.TestCase):
torch.testing.assert_allclose(hf_logits, EXPECTED_MEAN_LOGITS, rtol=6e-3, atol=9e-3)
@unittest.skip(
"Unless we stop stripping left and right by default for all special tokens, the expected ids obtained here will not match the original ones. Wait for https://github.com/huggingface/transformers/pull/23909 to be merged"
)
def test_small_generate(self):
# Generate test using the smalled switch-C model.
model = SwitchTransformersForConditionalGeneration.from_pretrained(
"google/switch-base-8", torch_dtype=torch.bfloat16
).eval()
tokenizer = AutoTokenizer.from_pretrained("t5-small", use_fast=False)
tokenizer = AutoTokenizer.from_pretrained("t5-small", use_fast=False, legacy=False)
model = model.to(torch_device)
input_ids = tokenizer(
@@ -1169,12 +1172,15 @@ class SwitchTransformerModelIntegrationTests(unittest.TestCase):
EXPECTED_OUTPUT = "<pad><extra_id_0> man<extra_id_1> beer<extra_id_2> a<extra_id_3> whiskey<extra_id_4>.</s>"
self.assertEqual(output_str, EXPECTED_OUTPUT)
@unittest.skip(
"Unless we stop stripping left and right by default for all special tokens, the expected ids obtained here will not match the original ones. Wait for https://github.com/huggingface/transformers/pull/23909 to be merged"
)
def test_small_batch_generate(self):
BATCH_SIZE = 4
model = SwitchTransformersForConditionalGeneration.from_pretrained(
"google/switch-base-8", torch_dtype=torch.bfloat16
).eval()
tokenizer = AutoTokenizer.from_pretrained("t5-small", use_fast=False)
tokenizer = AutoTokenizer.from_pretrained("t5-small", use_fast=False, legacy=False)
inputs = [
"A <extra_id_0> walks into a bar and orders a <extra_id_1> with <extra_id_2> pinch of <extra_id_3>."

View File

@@ -19,7 +19,7 @@ import tempfile
import unittest
from transformers import SPIECE_UNDERLINE, AddedToken, BatchEncoding, T5Tokenizer, T5TokenizerFast
from transformers.testing_utils import get_tests_dir, require_sentencepiece, require_tokenizers, slow
from transformers.testing_utils import get_tests_dir, require_sentencepiece, require_seqio, require_tokenizers, slow
from transformers.utils import cached_property, is_tf_available, is_torch_available
from ...test_tokenization_common import TokenizerTesterMixin
@@ -381,7 +381,7 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
def test_get_sentinel_tokens(self):
tokenizer = T5Tokenizer(SAMPLE_VOCAB, extra_ids=10)
sentinel_tokens = tokenizer.get_sentinel_tokens()
self.assertEquals(len(sentinel_tokens), 10)
self.assertEqual(len(sentinel_tokens), 10)
self.assertListEqual(sorted(sentinel_tokens), sorted([f"<extra_id_{str(i)}>" for i in range(0, 10)]))
self.assertTrue([re.search(r"<extra_id_\d+>", token) is not None for token in sentinel_tokens])
@@ -392,7 +392,7 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
def test_get_sentinel_tokens_for_fasttokenizer(self):
tokenizer = T5TokenizerFast(SAMPLE_VOCAB, extra_ids=10)
sentinel_tokens = tokenizer.get_sentinel_tokens()
self.assertEquals(len(sentinel_tokens), 10)
self.assertEqual(len(sentinel_tokens), 10)
self.assertListEqual(sorted(sentinel_tokens), sorted([f"<extra_id_{str(i)}>" for i in range(0, 10)]))
self.assertTrue([re.search(r"<extra_id_\d+>", token) is not None for token in sentinel_tokens])
@@ -400,34 +400,151 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
tokenizer = T5TokenizerFast(SAMPLE_VOCAB, extra_ids=10)
self.assertListEqual(sorted(tokenizer.get_sentinel_token_ids()), sorted(range(1000, 1010)))
def test_encode_extra_ids(self):
tokenizer = T5Tokenizer(SAMPLE_VOCAB, extra_ids=0)
@require_sentencepiece
@require_tokenizers
class CommonSpmIntegrationTests(unittest.TestCase):
"""
A class that regroups important test to make sure that we properly handle the special tokens.
"""
@classmethod
def setUpClass(cls):
tokenizer = T5Tokenizer(SAMPLE_VOCAB, extra_ids=0, legacy=False)
tokenizer.add_special_tokens({"additional_special_tokens": ["<extra_id_0>"]})
tokenizer._create_trie(tokenizer.all_special_tokens)
# TODO ArthurZ the above is necessary as addedTokens / intialization sucks. Trie is not correctly created
# So the extra ids are split....
cls.tokenizer = tokenizer
input_ids = tokenizer.encode(". Hello")
self.assertEquals(input_ids, [7, 4, 156, 86, 20, 2])
tokens = tokenizer.tokenize(". Hello")
self.assertEquals(tokens, ["", ".", "▁He", "ll", "o"])
def test_add_dummy_prefix(self):
# make sure `'▁'` is prepended, and outputs match sp_model's
# `sentencepiece.NormalizerSpec.add_dummy_prefix` attribute
input_ids = self.tokenizer.encode(". Hello", add_special_tokens=False)
self.assertEqual(input_ids, [7, 4, 156, 86, 20])
sp_encode = self.tokenizer.sp_model.encode(". Hello")
self.assertEqual(input_ids, sp_encode)
tokens = self.tokenizer.tokenize(". Hello")
self.assertEqual(tokens, ["", ".", "▁He", "ll", "o"])
input_ids = tokenizer.encode(" . Hello")
self.assertEquals(input_ids, [7, 4, 156, 86, 20, 2])
tokens = tokenizer.tokenize(" . Hello")
self.assertEquals(tokens, ["", ".", "▁He", "ll", "o"])
def test_remove_extra_whitespaces(self):
# make sure the extra spaces are eaten
# sentencepiece.NormalizerSpec.remove_extra_whitespaces attribute
input_ids = self.tokenizer.encode(" . Hello", add_special_tokens=False)
self.assertEqual(input_ids, [7, 4, 156, 86, 20])
sp_encode = self.tokenizer.sp_model.encode(" . Hello")
self.assertEqual(input_ids, sp_encode)
tokens = self.tokenizer.tokenize(" . Hello")
self.assertEqual(tokens, ["", ".", "▁He", "ll", "o"])
input_ids = tokenizer.encode("Hello, <extra_id_0>I")
self.assertEquals(input_ids, [156, 86, 20, 3, 999, 8, 2])
tokens = tokenizer.tokenize("Hello, <extra_id_0>I")
self.assertEquals(tokens, ["▁He", "ll", "o", ",", "<extra_id_0>", "▁I"])
# `'▁'` is also a whitespace
input_ids = self.tokenizer.encode("▁He is not")
self.assertEqual(input_ids, [156, 46, 44, 2])
tokens = self.tokenizer.tokenize("▁He is not")
self.assertEqual(tokens, ["▁He", "▁is", "▁not"]) # no extra space added
input_ids = tokenizer.encode("Hello, <extra_id_0>,")
self.assertEquals(input_ids, [156, 86, 20, 3, 999, 3, 2])
tokens = tokenizer.tokenize("Hello, <extra_id_0>,")
self.assertEquals(tokens, ["▁He", "ll", "o", ",", "<extra_id_0>", ","])
input_ids = self.tokenizer.encode("He is not<extra_id_0> ▁He")
# here t5x does not eat with lstrip, so there is and extra ▁He in the original one
# TODO @arthurzucker we should probably not srip right since it is done by default
# for certain models...
self.assertEqual(input_ids, [156, 46, 44, 999, 0, 2])
tokens = self.tokenizer.tokenize("▁He is not<extra_id_0> ▁He")
self.assertEqual(tokens, ["▁He", "▁is", "▁not", "<extra_id_0>", "He"]) # spaces are eaten by spm + our strip
# make sure that the output after the extra id is the same as if
# extra_id was not there
input_ids = self.tokenizer.encode("▁He is not ▁He")
self.assertEqual(input_ids, [156, 46, 44, 156, 2])
tokens = self.tokenizer.tokenize("▁He is not ▁He")
self.assertEqual(tokens, ["▁He", "▁is", "▁not", "▁He"]) # spaces are eaten by spm even if not start
input_ids = tokenizer.encode(" <extra_id_0> ,")
self.assertEquals(input_ids, [999, 3, 2])
tokens = tokenizer.tokenize(" <extra_id_0> ,")
self.assertEquals(tokens, ["<extra_id_0>", ","]) # spaces are eaten by rstrip / lstrip
def test_character_after_special_token(self):
# Make sure that `tokenizer.tokenize` is similar to
# adding the equivalent special token to the vocab
input_ids = self.tokenizer.encode("Hey <extra_id_0>I")
self.assertEqual(input_ids, [156, 30, 999, 100, 2])
tokens = self.tokenizer.tokenize("Hey <extra_id_0>I")
self.assertEqual(tokens, ["▁He", "y", "<extra_id_0>", "I"])
input_ids = self.tokenizer.encode("Hello, <extra_id_0>,")
self.assertEqual(input_ids, [156, 86, 20, 3, 999, 3, 2])
tokens = self.tokenizer.tokenize("Hello, <extra_id_0>,")
self.assertEqual(tokens, ["▁He", "ll", "o", ",", "<extra_id_0>", ","])
def test_special_tokens_strip(self):
input_ids = self.tokenizer.encode(" <extra_id_0> ,")
self.assertEqual(input_ids, [999, 3, 2])
tokens = self.tokenizer.tokenize(" <extra_id_0> ,")
# spaces are eaten by rstrip / lstrip
self.assertEqual(tokens, ["<extra_id_0>", ","])
# test with a begin of word like `▁He`
input_ids = self.tokenizer.encode("No <extra_id_0> He")
self.assertEqual(input_ids, [284, 999, 0, 2])
# spaces are eaten by rstrip / lstrip, so this is expected. Don't strip otherwise you break
tokens = self.tokenizer.tokenize("No <extra_id_0> He")
self.assertEqual(tokens, ["▁No", "<extra_id_0>", "He"])
# Make sure this does not happen if we don't strip
tokenizer = T5Tokenizer(SAMPLE_VOCAB, extra_ids=0)
tokenizer.add_special_tokens({"bos_token": AddedToken("<bos>")})
input_ids = tokenizer.encode("No <bos> He")
self.assertEqual(input_ids, [284, 1000, 156, 2])
tokens = tokenizer.tokenize("No <bos> He")
# the first `' '` after `'No'` is eaten by spm:
self.assertEqual(tokenizer.sp_model.encode("No ", out_type=str), ["▁No"])
self.assertEqual(tokens, ["▁No", "<bos>", "▁He"])
@require_seqio
@unittest.skipIf(
os.getenv("RUN_TOKENIZER_INTEGRATION", "0") == "0",
"RUN_TOKENIZER_INTEGRATION=1 to run tokenizer integration tests",
)
def test_integration_seqio(self):
from datasets import load_dataset
from seqio import SentencePieceVocabulary
ds = load_dataset("xnli", "all_languages", split="train+test+validation")
# TODO ArthurZucker fix the 3 commented tests with #23909
input_texts = [
"Bonjour <extra_id_0>.",
# "Bonjour<extra_id_0>.", # this will fail. In T5 the special token has to be at the end.
# because in T5 they add `_<extra_id_0>` to the vocab, not `<extra_id_0>`.
" Hey <extra_id_0>I love you",
# "Hey <extra_id_0> I love you", # this will fail, we strip left, to _I vs I
# "Hey <extra_id_0>▁He", # this will fail for the same reason, we replace `_` then strip
]
import tqdm
# Test with umt5
vocab_path = "gs://t5-data/vocabs/umt5.256000/sentencepiece.model"
t5x_tokenizer = SentencePieceVocabulary(vocab_path, extra_ids=300)
hf_tokenizer = T5Tokenizer.from_pretrained("google/umt5-small", legacy=False)
for text in input_texts:
self.assertEqual(
hf_tokenizer.encode(text, add_special_tokens=False), t5x_tokenizer.tokenizer.tokenize(text), f"{text}"
)
for texts in tqdm.tqdm(ds["premise"]):
for text in texts:
self.assertEqual(
hf_tokenizer.encode(text, add_special_tokens=False),
t5x_tokenizer.tokenizer.tokenize(text),
f"{text}",
)
# Test with T5
hf_tokenizer = T5Tokenizer.from_pretrained("t5-small")
vocab_path = "gs://t5-data/vocabs/cc_all.32000/sentencepiece.model"
t5x_tokenizer = SentencePieceVocabulary(vocab_path, extra_ids=300)
for text in input_texts:
self.assertEqual(
hf_tokenizer.encode(text, add_special_tokens=False), t5x_tokenizer.tokenizer.tokenize(text), f"{text}"
)
for texts in tqdm.tqdm(ds["premise"]):
for text in texts:
self.assertEqual(
hf_tokenizer.encode(text, add_special_tokens=False),
t5x_tokenizer.tokenizer.tokenize(text),
f"{text}",
)

View File

@@ -347,13 +347,16 @@ class UMT5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
@require_tokenizers
class Umt5IntegrationTest(unittest.TestCase):
@slow
@unittest.skip(
"Unless we stop stripping left and right by default for all special tokens, the expected ids obtained here will not match the original ones. Wait for https://github.com/huggingface/transformers/pull/23909 to be merged"
)
def test_small_integration_test(self):
"""
For comparison run the kaggle notbook available here : https://www.kaggle.com/arthurzucker/umt5-inference
"""
model = UMT5ForConditionalGeneration.from_pretrained("google/umt5-small", return_dict=True).to(torch_device)
tokenizer = AutoTokenizer.from_pretrained("google/umt5-small", use_fast=False)
tokenizer = AutoTokenizer.from_pretrained("google/umt5-small", use_fast=False, legacy=False)
input_text = [
"Bonjour monsieur <extra_id_0> bien <extra_id_1>.",
"No se como puedo <extra_id_0>.",
@@ -373,7 +376,7 @@ class Umt5IntegrationTest(unittest.TestCase):
]
)
# fmt: on
self.assertEqual(input_ids, EXPECTED_IDS)
torch.testing.assert_allclose(input_ids, EXPECTED_IDS)
generated_ids = model.generate(input_ids.to(torch_device))
EXPECTED_FILLING = [
@@ -384,4 +387,4 @@ class Umt5IntegrationTest(unittest.TestCase):
"<pad><extra_id_0>nyone who<extra_id_1> drink<extra_id_2> a<extra_id_3> alcohol<extra_id_4> A<extra_id_5> A. This<extra_id_6> I<extra_id_7><extra_id_52><extra_id_53></s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>",
]
filling = tokenizer.batch_decode(generated_ids)
self.assertTrue(filling, EXPECTED_FILLING)
self.assertEqual(filling, EXPECTED_FILLING)