[cleanup] test_tokenization_common.py (#4390)

This commit is contained in:
Sam Shleifer
2020-05-19 10:46:55 -04:00
committed by GitHub
parent 8f1d047148
commit 07dd7c2fd8
13 changed files with 62 additions and 98 deletions

View File

@@ -17,6 +17,7 @@
import os
import unittest
from transformers.file_utils import cached_property
from transformers.tokenization_xlm_roberta import SPIECE_UNDERLINE, XLMRobertaTokenizer
from .test_tokenization_common import TokenizerTesterMixin
@@ -37,14 +38,6 @@ class XLMRobertaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
tokenizer = XLMRobertaTokenizer(SAMPLE_VOCAB, keep_accents=True)
tokenizer.save_pretrained(self.tmpdirname)
def get_tokenizer(self, **kwargs):
    """Load an ``XLMRobertaTokenizer`` from ``self.tmpdirname``.

    All keyword arguments are passed through unchanged to
    ``XLMRobertaTokenizer.from_pretrained``.
    """
    loader = XLMRobertaTokenizer.from_pretrained
    return loader(self.tmpdirname, **kwargs)
def get_input_output_texts(self):
    """Return the ``(input_text, output_text)`` pair; both are the same sentence."""
    sample = "This is a test"
    return sample, sample
def test_full_tokenizer(self):
tokenizer = XLMRobertaTokenizer(SAMPLE_VOCAB, keep_accents=True)
@@ -121,22 +114,22 @@ class XLMRobertaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
],
)
@cached_property
def big_tokenizer(self):
    """Pretrained ``xlm-roberta-base`` tokenizer, exposed via ``cached_property``.

    NOTE(review): presumably loaded once per test instance and then reused;
    confirm against the ``cached_property`` helper in ``transformers.file_utils``.
    """
    pretrained = XLMRobertaTokenizer.from_pretrained("xlm-roberta-base")
    return pretrained
@slow
def test_tokenization_base_easy_symbols(self):
    """Check that a short ASCII sentence encodes to the expected token ids.

    The sentence is encoded with the pretrained ``xlm-roberta-base``
    tokenizer obtained from ``self.big_tokenizer`` and the resulting ids
    are compared against a hard-coded reference list.
    """
    symbols = "Hello World!"
    original_tokenizer_encodings = [0, 35378, 6661, 38, 2]
    # Reference snippet kept from the original (presumably how the expected
    # ids above were produced with fairseq):
    # xlmr = torch.hub.load('pytorch/fairseq', 'xlmr.base')  # xlmr.large has same tokenizer
    # xlmr.eval()
    # xlmr.encode(symbols)

    # Use the shared ``big_tokenizer`` property instead of loading
    # "xlm-roberta-base" a second time into a local and asserting twice.
    self.assertListEqual(original_tokenizer_encodings, self.big_tokenizer.encode(symbols))
@slow
def test_tokenization_base_hard_symbols(self):
tokenizer = XLMRobertaTokenizer.from_pretrained("xlm-roberta-base")
symbols = 'This is a very long text with a lot of weird characters, such as: . , ~ ? ( ) " [ ] ! : - . Also we will add words that should not exsist and be tokenized to <unk>, such as saoneuhaoesuth'
original_tokenizer_encodings = [
0,
@@ -209,4 +202,4 @@ class XLMRobertaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
# xlmr.eval()
# xlmr.encode(symbols)
self.assertListEqual(original_tokenizer_encodings, tokenizer.encode(symbols))
self.assertListEqual(original_tokenizer_encodings, self.big_tokenizer.encode(symbols))