Use lru_cache for tokenization tests (#36818)
* fix
* fix
* fix
* fix

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -13,8 +13,9 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import unittest
|
||||
from functools import lru_cache
|
||||
|
||||
from tests.test_tokenization_common import TokenizerTesterMixin
|
||||
from tests.test_tokenization_common import TokenizerTesterMixin, use_cache_if_possible
|
||||
from transformers import SplinterTokenizerFast, is_tf_available, is_torch_available
|
||||
from transformers.models.splinter import SplinterTokenizer
|
||||
from transformers.testing_utils import get_tests_dir, slow
|
||||
@@ -40,20 +41,29 @@ class SplinterTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
pre_trained_model_path = "tau/splinter-base"
|
||||
|
||||
# Copied from transformers.models.siglip.SiglipTokenizationTest.setUp
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super().setUpClass()
|
||||
tokenizer = SplinterTokenizer(SAMPLE_VOCAB)
|
||||
tokenizer.vocab["[UNK]"] = len(tokenizer.vocab)
|
||||
tokenizer.vocab["[QUESTION]"] = len(tokenizer.vocab)
|
||||
tokenizer.vocab["."] = len(tokenizer.vocab)
|
||||
tokenizer.add_tokens("this is a test thou shall not determine rigor truly".split())
|
||||
tokenizer.save_pretrained(self.tmpdirname)
|
||||
tokenizer.save_pretrained(cls.tmpdirname)
|
||||
|
||||
def get_tokenizer(self, **kwargs) -> SplinterTokenizer:
|
||||
return self.tokenizer_class.from_pretrained(self.tmpdirname, **kwargs)
|
||||
@classmethod
|
||||
@use_cache_if_possible
|
||||
@lru_cache(maxsize=64)
|
||||
def get_tokenizer(cls, pretrained_name=None, **kwargs) -> SplinterTokenizer:
|
||||
pretrained_name = pretrained_name or cls.tmpdirname
|
||||
return cls.tokenizer_class.from_pretrained(pretrained_name, **kwargs)
|
||||
|
||||
def get_rust_tokenizer(self, **kwargs) -> SplinterTokenizerFast:
|
||||
return self.rust_tokenizer_class.from_pretrained(self.tmpdirname, **kwargs)
|
||||
@classmethod
|
||||
@use_cache_if_possible
|
||||
@lru_cache(maxsize=64)
|
||||
def get_rust_tokenizer(cls, pretrained_name=None, **kwargs) -> SplinterTokenizerFast:
|
||||
pretrained_name = pretrained_name or cls.tmpdirname
|
||||
return cls.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs)
|
||||
|
||||
# Copied from transformers.models.siglip.SiglipTokenizationTest.test_get_vocab
|
||||
def test_get_vocab(self):
|
||||
|
||||
Reference in New Issue
Block a user