Cleanup fast tokenizers integration (#3706)

* First pass on utility classes and python tokenizers

* finishing cleanup pass

* style and quality

* Fix tests

* Updating following @mfuntowicz comment

* style and quality

* Fix Roberta

* fix batch_size/seq_length inBatchEncoding

* add alignement methods + tests

* Fix OpenAI and Transfo-XL tokenizers

* adding trim_offsets=True default for GPT2 et RoBERTa

* style and quality

* fix tests

* add_prefix_space in roberta

* bump up tokenizers to rc7

* style

* unfortunately tensorfow does like these - removing shape/seq_len for now

* Update src/transformers/tokenization_utils.py

Co-Authored-By: Stefan Schweter <stefan@schweter.it>

* Adding doc and docstrings

* making flake8 happy

Co-authored-by: Stefan Schweter <stefan@schweter.it>
This commit is contained in:
Thomas Wolf
2020-04-18 13:43:57 +02:00
committed by GitHub
parent 60a42ef1c0
commit 827d6d6ef0
28 changed files with 1031 additions and 503 deletions

View File

@@ -1,3 +1,4 @@
import logging
import unittest
from collections import namedtuple
from itertools import takewhile
@@ -21,6 +22,10 @@ from transformers.tokenization_roberta import RobertaTokenizerFast
from transformers.tokenization_transfo_xl import TransfoXLTokenizerFast
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
NON_ENGLISH_TAGS = ["chinese", "dutch", "french", "finnish", "german", "multilingual"]
Tokenizer = namedtuple("Tokenizer", ["name", "rust_cls", "python_cls", "vocab_key", "filter"])
@@ -83,6 +88,85 @@ class CommonFastTokenizerTest(unittest.TestCase):
self.assert_add_tokens(tokenizer_r)
self.assert_offsets_mapping(tokenizer_r)
self.assert_add_special_tokens(tokenizer_r)
self.assert_alignement_methods(tokenizer_r)
def assert_alignement_methods(self, tokenizer_r):
words = ["Wonderful", "no", "inspiration", "example", "with", "subtoken"]
text = " ".join(words)
batch_size = 3
encoding = tokenizer_r.encode_plus(text, add_special_tokens=False)
batch_encoding = tokenizer_r.batch_encode_plus([text] * batch_size, add_special_tokens=False)
num_tokens = len(encoding["input_ids"])
last_word_index = len(words) - 1
last_token_index = num_tokens - 1
last_batch_index = batch_size - 1
last_char_index = len(text) - 1
# words, tokens
self.assertEqual(len(encoding.words(0)), num_tokens)
self.assertEqual(max(encoding.words(0)), last_word_index)
self.assertEqual(min(encoding.words(0)), 0)
self.assertEqual(len(batch_encoding.words(last_batch_index)), num_tokens)
self.assertEqual(max(batch_encoding.words(last_batch_index)), last_word_index)
self.assertEqual(min(batch_encoding.words(last_batch_index)), 0)
self.assertEqual(len(encoding.tokens(0)), num_tokens)
# Assert token_to_word
self.assertEqual(encoding.token_to_word(0), 0)
self.assertEqual(encoding.token_to_word(0, 0), 0)
self.assertEqual(encoding.token_to_word(last_token_index), last_word_index)
self.assertEqual(encoding.token_to_word(0, last_token_index), last_word_index)
self.assertEqual(batch_encoding.token_to_word(1, 0), 0)
self.assertEqual(batch_encoding.token_to_word(0, last_token_index), last_word_index)
self.assertEqual(batch_encoding.token_to_word(last_batch_index, last_token_index), last_word_index)
# Assert word_to_tokens
self.assertEqual(encoding.word_to_tokens(0).start, 0)
self.assertEqual(encoding.word_to_tokens(0, 0).start, 0)
self.assertEqual(encoding.word_to_tokens(last_word_index).end, last_token_index + 1)
self.assertEqual(encoding.word_to_tokens(0, last_word_index).end, last_token_index + 1)
self.assertEqual(batch_encoding.word_to_tokens(1, 0).start, 0)
self.assertEqual(batch_encoding.word_to_tokens(0, last_word_index).end, last_token_index + 1)
self.assertEqual(batch_encoding.word_to_tokens(last_batch_index, last_word_index).end, last_token_index + 1)
# Assert token_to_chars
self.assertEqual(encoding.token_to_chars(0).start, 0)
self.assertEqual(encoding.token_to_chars(0, 0).start, 0)
self.assertEqual(encoding.token_to_chars(last_token_index).end, last_char_index + 1)
self.assertEqual(encoding.token_to_chars(0, last_token_index).end, last_char_index + 1)
self.assertEqual(batch_encoding.token_to_chars(1, 0).start, 0)
self.assertEqual(batch_encoding.token_to_chars(0, last_token_index).end, last_char_index + 1)
self.assertEqual(batch_encoding.token_to_chars(last_batch_index, last_token_index).end, last_char_index + 1)
# Assert char_to_token
self.assertEqual(encoding.char_to_token(0), 0)
self.assertEqual(encoding.char_to_token(0, 0), 0)
self.assertEqual(encoding.char_to_token(last_char_index), last_token_index)
self.assertEqual(encoding.char_to_token(0, last_char_index), last_token_index)
self.assertEqual(batch_encoding.char_to_token(1, 0), 0)
self.assertEqual(batch_encoding.char_to_token(0, last_char_index), last_token_index)
self.assertEqual(batch_encoding.char_to_token(last_batch_index, last_char_index), last_token_index)
# Assert char_to_word
self.assertEqual(encoding.char_to_word(0), 0)
self.assertEqual(encoding.char_to_word(0, 0), 0)
self.assertEqual(encoding.char_to_word(last_char_index), last_word_index)
self.assertEqual(encoding.char_to_word(0, last_char_index), last_word_index)
self.assertEqual(batch_encoding.char_to_word(1, 0), 0)
self.assertEqual(batch_encoding.char_to_word(0, last_char_index), last_word_index)
self.assertEqual(batch_encoding.char_to_word(last_batch_index, last_char_index), last_word_index)
# Assert word_to_chars
self.assertEqual(encoding.word_to_chars(0).start, 0)
self.assertEqual(encoding.word_to_chars(0, 0).start, 0)
self.assertEqual(encoding.word_to_chars(last_word_index).end, last_char_index + 1)
self.assertEqual(encoding.word_to_chars(0, last_word_index).end, last_char_index + 1)
self.assertEqual(batch_encoding.word_to_chars(1, 0).start, 0)
self.assertEqual(batch_encoding.word_to_chars(0, last_word_index).end, last_char_index + 1)
self.assertEqual(batch_encoding.word_to_chars(last_batch_index, last_word_index).end, last_char_index + 1)
def assert_tokenization_python_rust_equals(self, tokenizer_p, tokenizer_r):
# Ensure basic input match
@@ -306,7 +390,6 @@ class CommonFastTokenizerTest(unittest.TestCase):
self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"])
# Simple input
# TODO: Re-enable this test when batch_encode_plus with padding correctly handles padding
input_r = tokenizer_r.batch_encode_plus(
["This is a simple input 1", "This is a simple input 2"], max_length=max_length, pad_to_max_length=True
)
@@ -316,7 +399,6 @@ class CommonFastTokenizerTest(unittest.TestCase):
assert_batch_padded_input_match(input_r, input_p)
# Pair input
# TODO: Re-enable this test when batch_encode_plus with padding correctly handles padding
input_r = tokenizer_r.batch_encode_plus(
[
("This is a simple input 1", "This is a simple input 2"),