⚠️⚠️[T5Tokenize] Fix T5 family tokenizers⚠️⚠️ (#24565)
* don't add space before single letter chars that don't have a merge * fix the fix * fixup * add a test * more testing * fixup * hack to make sure fast is also fixed * update switch transformers test * revert convert slow * Update src/transformers/models/t5/tokenization_t5.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * add typechecking * quality --------- Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -19,11 +19,15 @@ import os
|
|||||||
import re
|
import re
|
||||||
import warnings
|
import warnings
|
||||||
from shutil import copyfile
|
from shutil import copyfile
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
import sentencepiece as spm
|
import sentencepiece as spm
|
||||||
|
|
||||||
from ...tokenization_utils import PreTrainedTokenizer
|
from ...tokenization_utils import PreTrainedTokenizer
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ...tokenization_utils_base import TextInput
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
|
|
||||||
|
|
||||||
@@ -51,6 +55,8 @@ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
|
|||||||
"t5-11b": 512,
|
"t5-11b": 512,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
SPIECE_UNDERLINE = "▁"
|
||||||
|
|
||||||
|
|
||||||
class T5Tokenizer(PreTrainedTokenizer):
|
class T5Tokenizer(PreTrainedTokenizer):
|
||||||
"""
|
"""
|
||||||
@@ -294,9 +300,17 @@ class T5Tokenizer(PreTrainedTokenizer):
|
|||||||
self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
|
self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
|
||||||
self.sp_model.Load(self.vocab_file)
|
self.sp_model.Load(self.vocab_file)
|
||||||
|
|
||||||
|
def tokenize(self, text: "TextInput", **kwargs) -> List[str]:
|
||||||
|
if not text.startswith(" "):
|
||||||
|
text = " " + text
|
||||||
|
return super().tokenize(text, **kwargs)
|
||||||
|
|
||||||
def _tokenize(self, text: str) -> List[str]:
|
def _tokenize(self, text: str) -> List[str]:
|
||||||
"""Take as input a string and return a list of strings (tokens) for words/sub-words"""
|
"""Take as input a string and return a list of strings (tokens) for words/sub-words"""
|
||||||
return self.sp_model.encode(text, out_type=str)
|
tokens = self.sp_model.encode(text, out_type=str)
|
||||||
|
if not text.startswith(" ") and tokens[0] == SPIECE_UNDERLINE:
|
||||||
|
tokens = tokens[1:]
|
||||||
|
return tokens
|
||||||
|
|
||||||
def _convert_token_to_id(self, token):
|
def _convert_token_to_id(self, token):
|
||||||
"""Converts a token (str) in an id using the vocab."""
|
"""Converts a token (str) in an id using the vocab."""
|
||||||
|
|||||||
@@ -1149,7 +1149,7 @@ class SwitchTransformerModelIntegrationTests(unittest.TestCase):
|
|||||||
model = SwitchTransformersForConditionalGeneration.from_pretrained(
|
model = SwitchTransformersForConditionalGeneration.from_pretrained(
|
||||||
"google/switch-base-8", torch_dtype=torch.bfloat16
|
"google/switch-base-8", torch_dtype=torch.bfloat16
|
||||||
).eval()
|
).eval()
|
||||||
tokenizer = AutoTokenizer.from_pretrained("t5-small")
|
tokenizer = AutoTokenizer.from_pretrained("t5-small", use_fast=False)
|
||||||
model = model.to(torch_device)
|
model = model.to(torch_device)
|
||||||
|
|
||||||
input_ids = tokenizer(
|
input_ids = tokenizer(
|
||||||
@@ -1160,13 +1160,13 @@ class SwitchTransformerModelIntegrationTests(unittest.TestCase):
|
|||||||
self.assertEqual(output_str, "drink.")
|
self.assertEqual(output_str, "drink.")
|
||||||
|
|
||||||
input_ids = tokenizer(
|
input_ids = tokenizer(
|
||||||
"A <extra_id_0> walks into a bar a orders a <extra_id_1> with <extra_id_2> pinch of <extra_id_3>.",
|
"A <extra_id_0> walks into a bar and orders a <extra_id_1> with <extra_id_2> pinch of <extra_id_3>.",
|
||||||
return_tensors="pt",
|
return_tensors="pt",
|
||||||
).input_ids.to(torch_device)
|
).input_ids.to(torch_device)
|
||||||
sequences = model.generate(input_ids)
|
sequences = model.generate(input_ids)
|
||||||
output_str = tokenizer.batch_decode(sequences, skip_special_tokens=False)[0]
|
output_str = tokenizer.batch_decode(sequences, skip_special_tokens=False)[0]
|
||||||
|
|
||||||
EXPECTED_OUTPUT = "<pad><extra_id_0> man<extra_id_1> beer<extra_id_2> a<extra_id_3> salt<extra_id_4>.</s>"
|
EXPECTED_OUTPUT = "<pad><extra_id_0> man<extra_id_1> beer<extra_id_2> a<extra_id_3> whiskey<extra_id_4>.</s>"
|
||||||
self.assertEqual(output_str, EXPECTED_OUTPUT)
|
self.assertEqual(output_str, EXPECTED_OUTPUT)
|
||||||
|
|
||||||
def test_small_batch_generate(self):
|
def test_small_batch_generate(self):
|
||||||
@@ -1174,10 +1174,10 @@ class SwitchTransformerModelIntegrationTests(unittest.TestCase):
|
|||||||
model = SwitchTransformersForConditionalGeneration.from_pretrained(
|
model = SwitchTransformersForConditionalGeneration.from_pretrained(
|
||||||
"google/switch-base-8", torch_dtype=torch.bfloat16
|
"google/switch-base-8", torch_dtype=torch.bfloat16
|
||||||
).eval()
|
).eval()
|
||||||
tokenizer = AutoTokenizer.from_pretrained("t5-small")
|
tokenizer = AutoTokenizer.from_pretrained("t5-small", use_fast=False)
|
||||||
|
|
||||||
inputs = [
|
inputs = [
|
||||||
"A <extra_id_0> walks into a bar a orders a <extra_id_1> with <extra_id_2> pinch of <extra_id_3>."
|
"A <extra_id_0> walks into a bar and orders a <extra_id_1> with <extra_id_2> pinch of <extra_id_3>."
|
||||||
] * BATCH_SIZE
|
] * BATCH_SIZE
|
||||||
encoded_input = tokenizer.batch_encode_plus(inputs, return_tensors="pt")
|
encoded_input = tokenizer.batch_encode_plus(inputs, return_tensors="pt")
|
||||||
|
|
||||||
|
|||||||
@@ -399,3 +399,35 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
def test_get_sentinel_token_ids_for_fasttokenizer(self):
|
def test_get_sentinel_token_ids_for_fasttokenizer(self):
|
||||||
tokenizer = T5TokenizerFast(SAMPLE_VOCAB, extra_ids=10)
|
tokenizer = T5TokenizerFast(SAMPLE_VOCAB, extra_ids=10)
|
||||||
self.assertListEqual(sorted(tokenizer.get_sentinel_token_ids()), sorted(range(1000, 1010)))
|
self.assertListEqual(sorted(tokenizer.get_sentinel_token_ids()), sorted(range(1000, 1010)))
|
||||||
|
|
||||||
|
def test_encode_extra_ids(self):
|
||||||
|
tokenizer = T5Tokenizer(SAMPLE_VOCAB, extra_ids=0)
|
||||||
|
tokenizer.add_special_tokens({"additional_special_tokens": ["<extra_id_0>"]})
|
||||||
|
tokenizer._create_trie(tokenizer.all_special_tokens)
|
||||||
|
# TODO ArthurZ the above is necessary as addedTokens / intialization sucks. Trie is not correctly created
|
||||||
|
# So the extra ids are split....
|
||||||
|
|
||||||
|
input_ids = tokenizer.encode(". Hello")
|
||||||
|
self.assertEquals(input_ids, [7, 4, 156, 86, 20, 2])
|
||||||
|
tokens = tokenizer.tokenize(". Hello")
|
||||||
|
self.assertEquals(tokens, ["▁", ".", "▁He", "ll", "o"])
|
||||||
|
|
||||||
|
input_ids = tokenizer.encode(" . Hello")
|
||||||
|
self.assertEquals(input_ids, [7, 4, 156, 86, 20, 2])
|
||||||
|
tokens = tokenizer.tokenize(" . Hello")
|
||||||
|
self.assertEquals(tokens, ["▁", ".", "▁He", "ll", "o"])
|
||||||
|
|
||||||
|
input_ids = tokenizer.encode("Hello, <extra_id_0>I")
|
||||||
|
self.assertEquals(input_ids, [156, 86, 20, 3, 999, 8, 2])
|
||||||
|
tokens = tokenizer.tokenize("Hello, <extra_id_0>I")
|
||||||
|
self.assertEquals(tokens, ["▁He", "ll", "o", ",", "<extra_id_0>", "▁I"])
|
||||||
|
|
||||||
|
input_ids = tokenizer.encode("Hello, <extra_id_0>,")
|
||||||
|
self.assertEquals(input_ids, [156, 86, 20, 3, 999, 3, 2])
|
||||||
|
tokens = tokenizer.tokenize("Hello, <extra_id_0>,")
|
||||||
|
self.assertEquals(tokens, ["▁He", "ll", "o", ",", "<extra_id_0>", ","])
|
||||||
|
|
||||||
|
input_ids = tokenizer.encode(" <extra_id_0> ,")
|
||||||
|
self.assertEquals(input_ids, [999, 3, 2])
|
||||||
|
tokens = tokenizer.tokenize(" <extra_id_0> ,")
|
||||||
|
self.assertEquals(tokens, ["<extra_id_0>", ","]) # spaces are eaten by rstrip / lstrip
|
||||||
|
|||||||
Reference in New Issue
Block a user