Cleanup fast tokenizers integration (#3706)
* First pass on utility classes and python tokenizers * finishing cleanup pass * style and quality * Fix tests * Updating following @mfuntowicz comment * style and quality * Fix Roberta * fix batch_size/seq_length in BatchEncoding * add alignment methods + tests * Fix OpenAI and Transfo-XL tokenizers * adding trim_offsets=True default for GPT2 and RoBERTa * style and quality * fix tests * add_prefix_space in roberta * bump up tokenizers to rc7 * style * unfortunately tensorflow doesn't like these - removing shape/seq_len for now * Update src/transformers/tokenization_utils.py Co-Authored-By: Stefan Schweter <stefan@schweter.it> * Adding doc and docstrings * making flake8 happy Co-authored-by: Stefan Schweter <stefan@schweter.it>
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import logging
|
||||
import unittest
|
||||
from collections import namedtuple
|
||||
from itertools import takewhile
|
||||
@@ -21,6 +22,10 @@ from transformers.tokenization_roberta import RobertaTokenizerFast
|
||||
from transformers.tokenization_transfo_xl import TransfoXLTokenizerFast
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
NON_ENGLISH_TAGS = ["chinese", "dutch", "french", "finnish", "german", "multilingual"]
|
||||
Tokenizer = namedtuple("Tokenizer", ["name", "rust_cls", "python_cls", "vocab_key", "filter"])
|
||||
|
||||
@@ -83,6 +88,85 @@ class CommonFastTokenizerTest(unittest.TestCase):
|
||||
self.assert_add_tokens(tokenizer_r)
|
||||
self.assert_offsets_mapping(tokenizer_r)
|
||||
self.assert_add_special_tokens(tokenizer_r)
|
||||
self.assert_alignement_methods(tokenizer_r)
|
||||
|
||||
def assert_alignement_methods(self, tokenizer_r):
    """Check the word/token/character alignment helpers of a fast tokenizer.

    A six-word sentence is encoded once on its own and once as a batch of
    three identical copies, both without special tokens. The resulting
    encodings are then probed at their first and last positions through
    ``words``, ``tokens``, ``token_to_word``, ``word_to_tokens``,
    ``token_to_chars``, ``char_to_token``, ``char_to_word`` and
    ``word_to_chars``.

    Args:
        tokenizer_r: tokenizer under test; it must provide ``encode_plus``
            and ``batch_encode_plus`` returning objects that expose the
            alignment methods listed above.
    """
    word_list = ["Wonderful", "no", "inspiration", "example", "with", "subtoken"]
    sentence = " ".join(word_list)
    n_batch = 3

    single = tokenizer_r.encode_plus(sentence, add_special_tokens=False)
    batched = tokenizer_r.batch_encode_plus([sentence] * n_batch, add_special_tokens=False)
    n_tokens = len(single["input_ids"])

    # Highest valid index along each axis that the assertions below probe.
    last_word = len(word_list) - 1
    last_token = n_tokens - 1
    last_batch = n_batch - 1
    last_char = len(sentence) - 1

    # words, tokens: one entry per token, word ids spanning 0..last_word.
    for enc, batch_index in ((single, 0), (batched, last_batch)):
        for reducer, expected in ((len, n_tokens), (max, last_word), (min, 0)):
            self.assertEqual(reducer(enc.words(batch_index)), expected)
    self.assertEqual(len(single.tokens(0)), n_tokens)

    def check_mapping(method_name, last_input, last_output, is_span=False):
        # Every mapping method is exercised with the same seven calls: the
        # first position (expected to map to 0) and the last position
        # (expected to map to ``last_output``), with and without an explicit
        # batch index. Span-returning methods are compared through ``.start``
        # for the first position and ``.end`` (exclusive, hence the +1) for
        # the last one.
        calls = (
            (single, (0,), True),
            (single, (0, 0), True),
            (single, (last_input,), False),
            (single, (0, last_input), False),
            (batched, (1, 0), True),
            (batched, (0, last_input), False),
            (batched, (last_batch, last_input), False),
        )
        for enc, args, is_first in calls:
            outcome = getattr(enc, method_name)(*args)
            if is_span:
                actual = outcome.start if is_first else outcome.end
                expected = 0 if is_first else last_output + 1
            else:
                actual = outcome
                expected = 0 if is_first else last_output
            self.assertEqual(actual, expected)

    # Assert token_to_word
    check_mapping("token_to_word", last_token, last_word)

    # Assert word_to_tokens
    check_mapping("word_to_tokens", last_word, last_token, is_span=True)

    # Assert token_to_chars
    check_mapping("token_to_chars", last_token, last_char, is_span=True)

    # Assert char_to_token
    check_mapping("char_to_token", last_char, last_token)

    # Assert char_to_word
    check_mapping("char_to_word", last_char, last_word)

    # Assert word_to_chars
    check_mapping("word_to_chars", last_word, last_char, is_span=True)
|
||||
|
||||
def assert_tokenization_python_rust_equals(self, tokenizer_p, tokenizer_r):
|
||||
# Ensure basic input match
|
||||
@@ -306,7 +390,6 @@ class CommonFastTokenizerTest(unittest.TestCase):
|
||||
self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"])
|
||||
|
||||
# Simple input
|
||||
# TODO: Re-enable this test when batch_encode_plus with padding correctly handles padding
|
||||
input_r = tokenizer_r.batch_encode_plus(
|
||||
["This is a simple input 1", "This is a simple input 2"], max_length=max_length, pad_to_max_length=True
|
||||
)
|
||||
@@ -316,7 +399,6 @@ class CommonFastTokenizerTest(unittest.TestCase):
|
||||
assert_batch_padded_input_match(input_r, input_p)
|
||||
|
||||
# Pair input
|
||||
# TODO: Re-enable this test when batch_encode_plus with padding correctly handles padding
|
||||
input_r = tokenizer_r.batch_encode_plus(
|
||||
[
|
||||
("This is a simple input 1", "This is a simple input 2"),
|
||||
|
||||
Reference in New Issue
Block a user