From 2818e505694ee4b5b02a9c7b51faf4dd137728d4 Mon Sep 17 00:00:00 2001 From: Anthony MOI Date: Tue, 24 Dec 2019 13:29:01 -0500 Subject: [PATCH] Add tests for fast tokenizers --- tests/test_tokenization_bert.py | 27 ++++++++++++++++++++++ tests/test_tokenization_common.py | 4 ++++ tests/test_tokenization_gpt2.py | 37 ++++++++++++++++++++++++++++++- 3 files changed, 67 insertions(+), 1 deletion(-) diff --git a/tests/test_tokenization_bert.py b/tests/test_tokenization_bert.py index 24e008d734..7af6cbee73 100644 --- a/tests/test_tokenization_bert.py +++ b/tests/test_tokenization_bert.py @@ -21,6 +21,7 @@ from transformers.tokenization_bert import ( VOCAB_FILES_NAMES, BasicTokenizer, BertTokenizer, + BertTokenizerFast, WordpieceTokenizer, _is_control, _is_punctuation, @@ -34,6 +35,7 @@ from .utils import slow class BertTokenizationTest(TokenizerTesterMixin, unittest.TestCase): tokenizer_class = BertTokenizer + test_rust_tokenizer = True def setUp(self): super(BertTokenizationTest, self).setUp() @@ -60,6 +62,9 @@ class BertTokenizationTest(TokenizerTesterMixin, unittest.TestCase): def get_tokenizer(self, **kwargs): return BertTokenizer.from_pretrained(self.tmpdirname, **kwargs) + def get_rust_tokenizer(self, **kwargs): + return BertTokenizerFast.from_pretrained(self.tmpdirname, **kwargs) + def get_input_output_texts(self): input_text = "UNwant\u00E9d,running" output_text = "unwanted, running" @@ -72,6 +77,28 @@ class BertTokenizationTest(TokenizerTesterMixin, unittest.TestCase): self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"]) self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9]) + def test_rust_and_python_full_tokenizers(self): + if not self.test_rust_tokenizer: + return + + tokenizer = self.get_tokenizer() + rust_tokenizer = self.get_rust_tokenizer(add_special_tokens=False) + + sequence = u"UNwant\u00E9d,running" + + tokens = tokenizer.tokenize(sequence) + rust_tokens = rust_tokenizer.tokenize(sequence) + self.assertListEqual(tokens, rust_tokens) + + ids = tokenizer.encode(sequence, add_special_tokens=False) + rust_ids = rust_tokenizer.encode(sequence) + self.assertListEqual(ids, rust_ids) + + rust_tokenizer = self.get_rust_tokenizer() + ids = tokenizer.encode(sequence) + rust_ids = rust_tokenizer.encode(sequence) + self.assertListEqual(ids, rust_ids) + def test_chinese(self): tokenizer = BasicTokenizer() diff --git a/tests/test_tokenization_common.py b/tests/test_tokenization_common.py index 035a0dc27f..1fa965ea40 100644 --- a/tests/test_tokenization_common.py +++ b/tests/test_tokenization_common.py @@ -23,6 +23,7 @@ import tempfile class TokenizerTesterMixin: tokenizer_class = None + test_rust_tokenizer = False def setUp(self): self.tmpdirname = tempfile.mkdtemp() @@ -33,6 +34,9 @@ class TokenizerTesterMixin: def get_tokenizer(self, **kwargs): raise NotImplementedError + def get_rust_tokenizer(self, **kwargs): + raise NotImplementedError + def get_input_output_texts(self): raise NotImplementedError diff --git a/tests/test_tokenization_gpt2.py b/tests/test_tokenization_gpt2.py index 7353d55178..fdd8026a8f 100644 --- a/tests/test_tokenization_gpt2.py +++ b/tests/test_tokenization_gpt2.py @@ -18,7 +18,7 @@ import json import os import unittest -from transformers.tokenization_gpt2 import VOCAB_FILES_NAMES, GPT2Tokenizer +from transformers.tokenization_gpt2 import VOCAB_FILES_NAMES, GPT2Tokenizer, GPT2TokenizerFast from .test_tokenization_common import TokenizerTesterMixin @@ -26,6 +26,7 @@ from .test_tokenization_common import TokenizerTesterMixin class GPT2TokenizationTest(TokenizerTesterMixin, unittest.TestCase): tokenizer_class = GPT2Tokenizer + test_rust_tokenizer = True def setUp(self): super(GPT2TokenizationTest, self).setUp() @@ -68,6 +69,10 @@ class GPT2TokenizationTest(TokenizerTesterMixin, unittest.TestCase): kwargs.update(self.special_tokens_map) return GPT2Tokenizer.from_pretrained(self.tmpdirname, **kwargs) + def get_rust_tokenizer(self, **kwargs): + kwargs.update(self.special_tokens_map) + return GPT2TokenizerFast.from_pretrained(self.tmpdirname, **kwargs) + def get_input_output_texts(self): input_text = "lower newer" output_text = "lower newer" @@ -83,3 +88,33 @@ class GPT2TokenizationTest(TokenizerTesterMixin, unittest.TestCase): input_tokens = tokens + [tokenizer.unk_token] input_bpe_tokens = [14, 15, 10, 9, 3, 2, 15, 19] self.assertListEqual(tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens) + + def test_rust_and_python_full_tokenizers(self): + if not self.test_rust_tokenizer: + return + + tokenizer = self.get_tokenizer() + rust_tokenizer = self.get_rust_tokenizer(add_special_tokens=False, add_prefix_space=True) + + sequence = u"lower newer" + + # Testing tokenization + tokens = tokenizer.tokenize(sequence, add_prefix_space=True) + rust_tokens = rust_tokenizer.tokenize(sequence) + self.assertListEqual(tokens, rust_tokens) + + # Testing conversion to ids without special tokens + ids = tokenizer.encode(sequence, add_special_tokens=False, add_prefix_space=True) + rust_ids = rust_tokenizer.encode(sequence) + self.assertListEqual(ids, rust_ids) + + # Testing conversion to ids with special tokens + rust_tokenizer = self.get_rust_tokenizer(add_prefix_space=True) + ids = tokenizer.encode(sequence, add_prefix_space=True) + rust_ids = rust_tokenizer.encode(sequence) + self.assertListEqual(ids, rust_ids) + + # Testing the unknown token + input_tokens = tokens + [rust_tokenizer.unk_token] + input_bpe_tokens = [14, 15, 10, 9, 3, 2, 15, 19] + self.assertListEqual(rust_tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)