Use lru_cache for tokenization tests (#36818)
* fix * fix * fix * fix --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -17,12 +17,13 @@ import os
|
||||
import re
|
||||
import tempfile
|
||||
import unittest
|
||||
from functools import lru_cache
|
||||
|
||||
from transformers import SPIECE_UNDERLINE, AddedToken, BatchEncoding, T5Tokenizer, T5TokenizerFast
|
||||
from transformers.testing_utils import get_tests_dir, require_sentencepiece, require_seqio, require_tokenizers, slow
|
||||
from transformers.utils import cached_property, is_tf_available, is_torch_available
|
||||
|
||||
from ...test_tokenization_common import TokenizerTesterMixin
|
||||
from ...test_tokenization_common import TokenizerTesterMixin, use_cache_if_possible
|
||||
|
||||
|
||||
SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model")
|
||||
@@ -44,12 +45,13 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
test_rust_tokenizer = True
|
||||
test_sentencepiece = True
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super().setUpClass()
|
||||
|
||||
# We have a SentencePiece fixture for testing
|
||||
tokenizer = T5Tokenizer(SAMPLE_VOCAB)
|
||||
tokenizer.save_pretrained(self.tmpdirname)
|
||||
tokenizer.save_pretrained(cls.tmpdirname)
|
||||
|
||||
def test_convert_token_and_id(self):
|
||||
"""Test ``_convert_token_to_id`` and ``_convert_id_to_token``."""
|
||||
@@ -145,11 +147,19 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
def t5_base_tokenizer_fast(self):
|
||||
return T5TokenizerFast.from_pretrained("google-t5/t5-base")
|
||||
|
||||
def get_tokenizer(self, **kwargs) -> T5Tokenizer:
|
||||
return self.tokenizer_class.from_pretrained(self.tmpdirname, **kwargs)
|
||||
@classmethod
|
||||
@use_cache_if_possible
|
||||
@lru_cache(maxsize=64)
|
||||
def get_tokenizer(cls, pretrained_name=None, **kwargs) -> T5Tokenizer:
|
||||
pretrained_name = pretrained_name or cls.tmpdirname
|
||||
return cls.tokenizer_class.from_pretrained(pretrained_name, **kwargs)
|
||||
|
||||
def get_rust_tokenizer(self, **kwargs) -> T5TokenizerFast:
|
||||
return self.rust_tokenizer_class.from_pretrained(self.tmpdirname, **kwargs)
|
||||
@classmethod
|
||||
@use_cache_if_possible
|
||||
@lru_cache(maxsize=64)
|
||||
def get_rust_tokenizer(cls, pretrained_name=None, **kwargs) -> T5TokenizerFast:
|
||||
pretrained_name = pretrained_name or cls.tmpdirname
|
||||
return cls.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs)
|
||||
|
||||
def test_rust_and_python_full_tokenizers(self):
|
||||
if not self.test_rust_tokenizer:
|
||||
@@ -275,10 +285,10 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
|
||||
added_tokens = [f"<extra_id_{i}>" for i in range(100)] + [AddedToken("<special>", lstrip=True)]
|
||||
|
||||
tokenizer_r = self.rust_tokenizer_class.from_pretrained(
|
||||
tokenizer_r = self.get_rust_tokenizer(
|
||||
pretrained_name, additional_special_tokens=added_tokens, **kwargs
|
||||
)
|
||||
tokenizer_cr = self.rust_tokenizer_class.from_pretrained(
|
||||
tokenizer_cr = self.get_rust_tokenizer(
|
||||
pretrained_name, additional_special_tokens=added_tokens, **kwargs, from_slow=True
|
||||
)
|
||||
tokenizer_p = self.tokenizer_class.from_pretrained(
|
||||
@@ -460,10 +470,8 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
EXPECTED_WITH_SPACE = [9459, 149, 33, 25, 692, 1]
|
||||
EXPECTED_WO_SPACE = [3845, 63, 149, 33, 25, 692, 1]
|
||||
|
||||
slow_ = self.tokenizer_class.from_pretrained(pretrained_name, add_prefix_space=False, legacy=False)
|
||||
fast_ = self.rust_tokenizer_class.from_pretrained(
|
||||
pretrained_name, add_prefix_space=False, legacy=False, from_slow=True
|
||||
)
|
||||
slow_ = self.get_tokenizer(pretrained_name, add_prefix_space=False, legacy=False)
|
||||
fast_ = self.get_rust_tokenizer(pretrained_name, add_prefix_space=False, legacy=False, from_slow=True)
|
||||
self.assertEqual(slow_.encode(inputs), EXPECTED_WO_SPACE)
|
||||
self.assertEqual(slow_.encode(inputs), fast_.encode(inputs))
|
||||
self.assertEqual(slow_.tokenize(inputs), ["He", "y", "▁how", "▁are", "▁you", "▁doing"])
|
||||
@@ -473,8 +481,8 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
fast_.decode(EXPECTED_WO_SPACE, skip_special_tokens=True),
|
||||
)
|
||||
|
||||
slow_ = self.tokenizer_class.from_pretrained(pretrained_name, add_prefix_space=True, legacy=False)
|
||||
fast_ = self.rust_tokenizer_class.from_pretrained(pretrained_name, add_prefix_space=True, legacy=False)
|
||||
slow_ = self.get_tokenizer(pretrained_name, add_prefix_space=True, legacy=False)
|
||||
fast_ = self.get_rust_tokenizer(pretrained_name, add_prefix_space=True, legacy=False)
|
||||
self.assertEqual(slow_.encode(inputs), EXPECTED_WITH_SPACE)
|
||||
self.assertEqual(slow_.encode(inputs), fast_.encode(inputs))
|
||||
self.assertEqual(slow_.tokenize(inputs), ["▁Hey", "▁how", "▁are", "▁you", "▁doing"])
|
||||
|
||||
Reference in New Issue
Block a user