[Patch-t5-tokenizer] Patches the changes on T5 to make sure previous behaviour is still valide for beginning of words (#24622)
* patch `_tokenize` function * more tests * properly fix * fixup * Update src/transformers/models/t5/tokenization_t5.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * fix without ifs * update * protect import * add python processing * is first needed * add doc and update with lefacy * updaate * fix T5 SPM converter * styling * fix T5 warning * add is_seqio_available * remove is_first * revert some changes * more tests and update * update llama test batterie * fixup * refactor T5 spm common tests * draft the llama tests * update * uopdate test * nits * refine * name nit * fix t5 tests * fix T5 * update * revert convert slow to fast changes that fail lots of tests * legacy support * fixup * nits is first not defined * don't use legacy behaviour for switch transformers * style * My attempt to check. * nits * fixes * update * fixup * Apply suggestions from code review Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * updates * fixup * add legacy warning * fixup * warning_once nit * update t5 documentation test * update llama tok documentation * add space to warning * nits * nit * Apply suggestions from code review Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * last nits --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
This commit is contained in:
@@ -1143,13 +1143,16 @@ class SwitchTransformerModelIntegrationTests(unittest.TestCase):
|
||||
|
||||
torch.testing.assert_allclose(hf_logits, EXPECTED_MEAN_LOGITS, rtol=6e-3, atol=9e-3)
|
||||
|
||||
@unittest.skip(
|
||||
"Unless we stop stripping left and right by default for all special tokens, the expected ids obtained here will not match the original ones. Wait for https://github.com/huggingface/transformers/pull/23909 to be merged"
|
||||
)
|
||||
def test_small_generate(self):
|
||||
# Generate test using the smalled switch-C model.
|
||||
|
||||
model = SwitchTransformersForConditionalGeneration.from_pretrained(
|
||||
"google/switch-base-8", torch_dtype=torch.bfloat16
|
||||
).eval()
|
||||
tokenizer = AutoTokenizer.from_pretrained("t5-small", use_fast=False)
|
||||
tokenizer = AutoTokenizer.from_pretrained("t5-small", use_fast=False, legacy=False)
|
||||
model = model.to(torch_device)
|
||||
|
||||
input_ids = tokenizer(
|
||||
@@ -1169,12 +1172,15 @@ class SwitchTransformerModelIntegrationTests(unittest.TestCase):
|
||||
EXPECTED_OUTPUT = "<pad><extra_id_0> man<extra_id_1> beer<extra_id_2> a<extra_id_3> whiskey<extra_id_4>.</s>"
|
||||
self.assertEqual(output_str, EXPECTED_OUTPUT)
|
||||
|
||||
@unittest.skip(
|
||||
"Unless we stop stripping left and right by default for all special tokens, the expected ids obtained here will not match the original ones. Wait for https://github.com/huggingface/transformers/pull/23909 to be merged"
|
||||
)
|
||||
def test_small_batch_generate(self):
|
||||
BATCH_SIZE = 4
|
||||
model = SwitchTransformersForConditionalGeneration.from_pretrained(
|
||||
"google/switch-base-8", torch_dtype=torch.bfloat16
|
||||
).eval()
|
||||
tokenizer = AutoTokenizer.from_pretrained("t5-small", use_fast=False)
|
||||
tokenizer = AutoTokenizer.from_pretrained("t5-small", use_fast=False, legacy=False)
|
||||
|
||||
inputs = [
|
||||
"A <extra_id_0> walks into a bar and orders a <extra_id_1> with <extra_id_2> pinch of <extra_id_3>."
|
||||
|
||||
Reference in New Issue
Block a user