Add tests for fast tokenizers

This commit is contained in:
Anthony MOI
2019-12-24 13:29:01 -05:00
parent 31c56f2e0b
commit 2818e50569
3 changed files with 67 additions and 1 deletions

View File

@@ -21,6 +21,7 @@ from transformers.tokenization_bert import (
VOCAB_FILES_NAMES,
BasicTokenizer,
BertTokenizer,
BertTokenizerFast,
WordpieceTokenizer,
_is_control,
_is_punctuation,
@@ -34,6 +35,7 @@ from .utils import slow
class BertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
tokenizer_class = BertTokenizer
test_rust_tokenizer = True
def setUp(self):
super(BertTokenizationTest, self).setUp()
@@ -60,6 +62,9 @@ class BertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
def get_tokenizer(self, **kwargs):
return BertTokenizer.from_pretrained(self.tmpdirname, **kwargs)
def get_rust_tokenizer(self, **kwargs):
return BertTokenizerFast.from_pretrained(self.tmpdirname, **kwargs)
def get_input_output_texts(self):
input_text = "UNwant\u00E9d,running"
output_text = "unwanted, running"
@@ -72,6 +77,28 @@ class BertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])
def test_rust_and_python_full_tokenizers(self):
if not self.test_rust_tokenizer:
return
tokenizer = self.get_tokenizer()
rust_tokenizer = self.get_rust_tokenizer(add_special_tokens=False)
sequence = u"UNwant\u00E9d,running"
tokens = tokenizer.tokenize(sequence)
rust_tokens = rust_tokenizer.tokenize(sequence)
self.assertListEqual(tokens, rust_tokens)
ids = tokenizer.encode(sequence, add_special_tokens=False)
rust_ids = rust_tokenizer.encode(sequence)
self.assertListEqual(ids, rust_ids)
rust_tokenizer = self.get_rust_tokenizer()
ids = tokenizer.encode(sequence)
rust_ids = rust_tokenizer.encode(sequence)
self.assertListEqual(ids, rust_ids)
def test_chinese(self):
tokenizer = BasicTokenizer()