[pegasus] Faster tokenizer tests (#7672)

This commit is contained in:
Stas Bekman
2020-10-09 08:10:32 -07:00
committed by GitHub
parent bc00b37a0d
commit b0f05e0c4c
8 changed files with 51 additions and 29 deletions

Binary file not shown.

View File

@@ -1,13 +1,15 @@
import unittest
from pathlib import Path
from transformers.file_utils import cached_property
from transformers.testing_utils import require_torch
from transformers.testing_utils import get_tests_dir, require_torch
from transformers.tokenization_pegasus import PegasusTokenizer, PegasusTokenizerFast
from .test_tokenization_common import TokenizerTesterMixin
SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece_no_bos.model")
class PegasusTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
tokenizer_class = PegasusTokenizer
@@ -17,11 +19,9 @@ class PegasusTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
def setUp(self):
super().setUp()
save_dir = Path(self.tmpdirname)
spm_file = PegasusTokenizer.vocab_files_names["vocab_file"]
if not (save_dir / spm_file).exists():
tokenizer = self.pegasus_large_tokenizer
tokenizer.save_pretrained(self.tmpdirname)
# We have a SentencePiece fixture for testing
tokenizer = PegasusTokenizer(SAMPLE_VOCAB)
tokenizer.save_pretrained(self.tmpdirname)
@cached_property
def pegasus_large_tokenizer(self):
@@ -32,10 +32,7 @@ class PegasusTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
pass
def get_tokenizer(self, **kwargs) -> PegasusTokenizer:
if not kwargs:
return self.pegasus_large_tokenizer
else:
return PegasusTokenizer.from_pretrained(self.tmpdirname, **kwargs)
return PegasusTokenizer.from_pretrained(self.tmpdirname, **kwargs)
def get_input_output_texts(self, tokenizer):
return ("This is a test", "This is a test")

View File

@@ -14,19 +14,18 @@
# limitations under the License.
import os
import unittest
from transformers import BatchEncoding
from transformers.file_utils import cached_property
from transformers.testing_utils import _torch_available
from transformers.testing_utils import _torch_available, get_tests_dir
from transformers.tokenization_t5 import T5Tokenizer, T5TokenizerFast
from transformers.tokenization_xlnet import SPIECE_UNDERLINE
from .test_tokenization_common import TokenizerTesterMixin
SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/test_sentencepiece.model")
SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model")
FRAMEWORK = "pt" if _torch_available else "tf"