⚠️⚠️[T5Tokenize] Fix T5 family tokenizers⚠️⚠️ (#24565)

* don't add space before single letter chars that don't have a merge

* fix the fix

* fixup

* add a test

* more testing

* fixup

* hack to make sure fast is also fixed

* update switch transformers test

* revert convert slow

* Update src/transformers/models/t5/tokenization_t5.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* add typechecking

* quality

---------

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Arthur
2023-06-30 14:00:43 +09:00
committed by GitHub
parent 9e28750287
commit b52a03cd3b
3 changed files with 53 additions and 7 deletions

View File

@@ -1149,7 +1149,7 @@ class SwitchTransformerModelIntegrationTests(unittest.TestCase):
model = SwitchTransformersForConditionalGeneration.from_pretrained(
"google/switch-base-8", torch_dtype=torch.bfloat16
).eval()
tokenizer = AutoTokenizer.from_pretrained("t5-small")
tokenizer = AutoTokenizer.from_pretrained("t5-small", use_fast=False)
model = model.to(torch_device)
input_ids = tokenizer(
@@ -1160,13 +1160,13 @@ class SwitchTransformerModelIntegrationTests(unittest.TestCase):
self.assertEqual(output_str, "drink.")
input_ids = tokenizer(
"A <extra_id_0> walks into a bar a orders a <extra_id_1> with <extra_id_2> pinch of <extra_id_3>.",
"A <extra_id_0> walks into a bar and orders a <extra_id_1> with <extra_id_2> pinch of <extra_id_3>.",
return_tensors="pt",
).input_ids.to(torch_device)
sequences = model.generate(input_ids)
output_str = tokenizer.batch_decode(sequences, skip_special_tokens=False)[0]
EXPECTED_OUTPUT = "<pad><extra_id_0> man<extra_id_1> beer<extra_id_2> a<extra_id_3> salt<extra_id_4>.</s>"
EXPECTED_OUTPUT = "<pad><extra_id_0> man<extra_id_1> beer<extra_id_2> a<extra_id_3> whiskey<extra_id_4>.</s>"
self.assertEqual(output_str, EXPECTED_OUTPUT)
def test_small_batch_generate(self):
@@ -1174,10 +1174,10 @@ class SwitchTransformerModelIntegrationTests(unittest.TestCase):
model = SwitchTransformersForConditionalGeneration.from_pretrained(
"google/switch-base-8", torch_dtype=torch.bfloat16
).eval()
tokenizer = AutoTokenizer.from_pretrained("t5-small")
tokenizer = AutoTokenizer.from_pretrained("t5-small", use_fast=False)
inputs = [
"A <extra_id_0> walks into a bar a orders a <extra_id_1> with <extra_id_2> pinch of <extra_id_3>."
"A <extra_id_0> walks into a bar and orders a <extra_id_1> with <extra_id_2> pinch of <extra_id_3>."
] * BATCH_SIZE
encoded_input = tokenizer.batch_encode_plus(inputs, return_tensors="pt")