Add tests for fast tokenizers

This commit is contained in:
Anthony MOI
2019-12-24 13:29:01 -05:00
parent 31c56f2e0b
commit 2818e50569
3 changed files with 67 additions and 1 deletions

View File

@@ -18,7 +18,7 @@ import json
import os
import unittest
from transformers.tokenization_gpt2 import VOCAB_FILES_NAMES, GPT2Tokenizer
from transformers.tokenization_gpt2 import VOCAB_FILES_NAMES, GPT2Tokenizer, GPT2TokenizerFast
from .test_tokenization_common import TokenizerTesterMixin
@@ -26,6 +26,7 @@ from .test_tokenization_common import TokenizerTesterMixin
class GPT2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
tokenizer_class = GPT2Tokenizer
test_rust_tokenizer = True
def setUp(self):
super(GPT2TokenizationTest, self).setUp()
@@ -68,6 +69,10 @@ class GPT2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
kwargs.update(self.special_tokens_map)
return GPT2Tokenizer.from_pretrained(self.tmpdirname, **kwargs)
def get_rust_tokenizer(self, **kwargs):
kwargs.update(self.special_tokens_map)
return GPT2TokenizerFast.from_pretrained(self.tmpdirname, **kwargs)
def get_input_output_texts(self):
input_text = "lower newer"
output_text = "lower newer"
@@ -83,3 +88,33 @@ class GPT2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
input_tokens = tokens + [tokenizer.unk_token]
input_bpe_tokens = [14, 15, 10, 9, 3, 2, 15, 19]
self.assertListEqual(tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
def test_rust_and_python_full_tokenizers(self):
if not self.test_rust_tokenizer:
return
tokenizer = self.get_tokenizer()
rust_tokenizer = self.get_rust_tokenizer(add_special_tokens=False, add_prefix_space=True)
sequence = u"lower newer"
# Testing tokenization
tokens = tokenizer.tokenize(sequence, add_prefix_space=True)
rust_tokens = rust_tokenizer.tokenize(sequence)
self.assertListEqual(tokens, rust_tokens)
# Testing conversion to ids without special tokens
ids = tokenizer.encode(sequence, add_special_tokens=False, add_prefix_space=True)
rust_ids = rust_tokenizer.encode(sequence)
self.assertListEqual(ids, rust_ids)
# Testing conversion to ids with special tokens
rust_tokenizer = self.get_rust_tokenizer(add_prefix_space=True)
ids = tokenizer.encode(sequence, add_prefix_space=True)
rust_ids = rust_tokenizer.encode(sequence)
self.assertListEqual(ids, rust_ids)
# Testing the unknown token
input_tokens = tokens + [rust_tokenizer.unk_token]
input_bpe_tokens = [14, 15, 10, 9, 3, 2, 15, 19]
self.assertListEqual(rust_tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)