Add tests for fast tokenizers
This commit is contained in:
@@ -21,6 +21,7 @@ from transformers.tokenization_bert import (
|
|||||||
VOCAB_FILES_NAMES,
|
VOCAB_FILES_NAMES,
|
||||||
BasicTokenizer,
|
BasicTokenizer,
|
||||||
BertTokenizer,
|
BertTokenizer,
|
||||||
|
BertTokenizerFast,
|
||||||
WordpieceTokenizer,
|
WordpieceTokenizer,
|
||||||
_is_control,
|
_is_control,
|
||||||
_is_punctuation,
|
_is_punctuation,
|
||||||
@@ -34,6 +35,7 @@ from .utils import slow
|
|||||||
class BertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
class BertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||||
|
|
||||||
tokenizer_class = BertTokenizer
|
tokenizer_class = BertTokenizer
|
||||||
|
test_rust_tokenizer = True
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
super(BertTokenizationTest, self).setUp()
|
super(BertTokenizationTest, self).setUp()
|
||||||
@@ -60,6 +62,9 @@ class BertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
def get_tokenizer(self, **kwargs):
|
def get_tokenizer(self, **kwargs):
|
||||||
return BertTokenizer.from_pretrained(self.tmpdirname, **kwargs)
|
return BertTokenizer.from_pretrained(self.tmpdirname, **kwargs)
|
||||||
|
|
||||||
|
def get_rust_tokenizer(self, **kwargs):
|
||||||
|
return BertTokenizerFast.from_pretrained(self.tmpdirname, **kwargs)
|
||||||
|
|
||||||
def get_input_output_texts(self):
|
def get_input_output_texts(self):
|
||||||
input_text = "UNwant\u00E9d,running"
|
input_text = "UNwant\u00E9d,running"
|
||||||
output_text = "unwanted, running"
|
output_text = "unwanted, running"
|
||||||
@@ -72,6 +77,28 @@ class BertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
|
self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
|
||||||
self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])
|
self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])
|
||||||
|
|
||||||
|
def test_rust_and_python_full_tokenizers(self):
|
||||||
|
if not self.test_rust_tokenizer:
|
||||||
|
return
|
||||||
|
|
||||||
|
tokenizer = self.get_tokenizer()
|
||||||
|
rust_tokenizer = self.get_rust_tokenizer(add_special_tokens=False)
|
||||||
|
|
||||||
|
sequence = u"UNwant\u00E9d,running"
|
||||||
|
|
||||||
|
tokens = tokenizer.tokenize(sequence)
|
||||||
|
rust_tokens = rust_tokenizer.tokenize(sequence)
|
||||||
|
self.assertListEqual(tokens, rust_tokens)
|
||||||
|
|
||||||
|
ids = tokenizer.encode(sequence, add_special_tokens=False)
|
||||||
|
rust_ids = rust_tokenizer.encode(sequence)
|
||||||
|
self.assertListEqual(ids, rust_ids)
|
||||||
|
|
||||||
|
rust_tokenizer = self.get_rust_tokenizer()
|
||||||
|
ids = tokenizer.encode(sequence)
|
||||||
|
rust_ids = rust_tokenizer.encode(sequence)
|
||||||
|
self.assertListEqual(ids, rust_ids)
|
||||||
|
|
||||||
def test_chinese(self):
|
def test_chinese(self):
|
||||||
tokenizer = BasicTokenizer()
|
tokenizer = BasicTokenizer()
|
||||||
|
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ import tempfile
|
|||||||
class TokenizerTesterMixin:
|
class TokenizerTesterMixin:
|
||||||
|
|
||||||
tokenizer_class = None
|
tokenizer_class = None
|
||||||
|
test_rust_tokenizer = False
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.tmpdirname = tempfile.mkdtemp()
|
self.tmpdirname = tempfile.mkdtemp()
|
||||||
@@ -33,6 +34,9 @@ class TokenizerTesterMixin:
|
|||||||
def get_tokenizer(self, **kwargs):
|
def get_tokenizer(self, **kwargs):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def get_rust_tokenizer(self, **kwargs):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
def get_input_output_texts(self):
|
def get_input_output_texts(self):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ import json
|
|||||||
import os
|
import os
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from transformers.tokenization_gpt2 import VOCAB_FILES_NAMES, GPT2Tokenizer
|
from transformers.tokenization_gpt2 import VOCAB_FILES_NAMES, GPT2Tokenizer, GPT2TokenizerFast
|
||||||
|
|
||||||
from .test_tokenization_common import TokenizerTesterMixin
|
from .test_tokenization_common import TokenizerTesterMixin
|
||||||
|
|
||||||
@@ -26,6 +26,7 @@ from .test_tokenization_common import TokenizerTesterMixin
|
|||||||
class GPT2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
class GPT2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||||
|
|
||||||
tokenizer_class = GPT2Tokenizer
|
tokenizer_class = GPT2Tokenizer
|
||||||
|
test_rust_tokenizer = True
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
super(GPT2TokenizationTest, self).setUp()
|
super(GPT2TokenizationTest, self).setUp()
|
||||||
@@ -68,6 +69,10 @@ class GPT2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
kwargs.update(self.special_tokens_map)
|
kwargs.update(self.special_tokens_map)
|
||||||
return GPT2Tokenizer.from_pretrained(self.tmpdirname, **kwargs)
|
return GPT2Tokenizer.from_pretrained(self.tmpdirname, **kwargs)
|
||||||
|
|
||||||
|
def get_rust_tokenizer(self, **kwargs):
|
||||||
|
kwargs.update(self.special_tokens_map)
|
||||||
|
return GPT2TokenizerFast.from_pretrained(self.tmpdirname, **kwargs)
|
||||||
|
|
||||||
def get_input_output_texts(self):
|
def get_input_output_texts(self):
|
||||||
input_text = "lower newer"
|
input_text = "lower newer"
|
||||||
output_text = "lower newer"
|
output_text = "lower newer"
|
||||||
@@ -83,3 +88,33 @@ class GPT2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
input_tokens = tokens + [tokenizer.unk_token]
|
input_tokens = tokens + [tokenizer.unk_token]
|
||||||
input_bpe_tokens = [14, 15, 10, 9, 3, 2, 15, 19]
|
input_bpe_tokens = [14, 15, 10, 9, 3, 2, 15, 19]
|
||||||
self.assertListEqual(tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
|
self.assertListEqual(tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
|
||||||
|
|
||||||
|
def test_rust_and_python_full_tokenizers(self):
|
||||||
|
if not self.test_rust_tokenizer:
|
||||||
|
return
|
||||||
|
|
||||||
|
tokenizer = self.get_tokenizer()
|
||||||
|
rust_tokenizer = self.get_rust_tokenizer(add_special_tokens=False, add_prefix_space=True)
|
||||||
|
|
||||||
|
sequence = u"lower newer"
|
||||||
|
|
||||||
|
# Testing tokenization
|
||||||
|
tokens = tokenizer.tokenize(sequence, add_prefix_space=True)
|
||||||
|
rust_tokens = rust_tokenizer.tokenize(sequence)
|
||||||
|
self.assertListEqual(tokens, rust_tokens)
|
||||||
|
|
||||||
|
# Testing conversion to ids without special tokens
|
||||||
|
ids = tokenizer.encode(sequence, add_special_tokens=False, add_prefix_space=True)
|
||||||
|
rust_ids = rust_tokenizer.encode(sequence)
|
||||||
|
self.assertListEqual(ids, rust_ids)
|
||||||
|
|
||||||
|
# Testing conversion to ids with special tokens
|
||||||
|
rust_tokenizer = self.get_rust_tokenizer(add_prefix_space=True)
|
||||||
|
ids = tokenizer.encode(sequence, add_prefix_space=True)
|
||||||
|
rust_ids = rust_tokenizer.encode(sequence)
|
||||||
|
self.assertListEqual(ids, rust_ids)
|
||||||
|
|
||||||
|
# Testing the unknown token
|
||||||
|
input_tokens = tokens + [rust_tokenizer.unk_token]
|
||||||
|
input_bpe_tokens = [14, 15, 10, 9, 3, 2, 15, 19]
|
||||||
|
self.assertListEqual(rust_tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
|
||||||
|
|||||||
Reference in New Issue
Block a user