[pegasus] Faster tokenizer tests (#7672)
This commit is contained in:
@@ -1,13 +1,15 @@
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from transformers.file_utils import cached_property
|
||||
from transformers.testing_utils import require_torch
|
||||
from transformers.testing_utils import get_tests_dir, require_torch
|
||||
from transformers.tokenization_pegasus import PegasusTokenizer, PegasusTokenizerFast
|
||||
|
||||
from .test_tokenization_common import TokenizerTesterMixin
|
||||
|
||||
|
||||
SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece_no_bos.model")
|
||||
|
||||
|
||||
class PegasusTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
|
||||
tokenizer_class = PegasusTokenizer
|
||||
@@ -17,11 +19,9 @@ class PegasusTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
|
||||
save_dir = Path(self.tmpdirname)
|
||||
spm_file = PegasusTokenizer.vocab_files_names["vocab_file"]
|
||||
if not (save_dir / spm_file).exists():
|
||||
tokenizer = self.pegasus_large_tokenizer
|
||||
tokenizer.save_pretrained(self.tmpdirname)
|
||||
# We have a SentencePiece fixture for testing
|
||||
tokenizer = PegasusTokenizer(SAMPLE_VOCAB)
|
||||
tokenizer.save_pretrained(self.tmpdirname)
|
||||
|
||||
@cached_property
|
||||
def pegasus_large_tokenizer(self):
|
||||
@@ -32,10 +32,7 @@ class PegasusTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
pass
|
||||
|
||||
def get_tokenizer(self, **kwargs) -> PegasusTokenizer:
|
||||
if not kwargs:
|
||||
return self.pegasus_large_tokenizer
|
||||
else:
|
||||
return PegasusTokenizer.from_pretrained(self.tmpdirname, **kwargs)
|
||||
return PegasusTokenizer.from_pretrained(self.tmpdirname, **kwargs)
|
||||
|
||||
def get_input_output_texts(self, tokenizer):
|
||||
return ("This is a test", "This is a test")
|
||||
|
||||
Reference in New Issue
Block a user