[HUGE] Refactoring tokenizers backend - padding - truncation - pre-tokenized pipeline - fast tokenizers - tests (#4510)
* Use tokenizers pre-tokenized pipeline * failing pretrokenized test * Fix is_pretokenized in python * add pretokenized tests * style and quality * better tests for batched pretokenized inputs * tokenizers clean up - new padding_strategy - split the files * [HUGE] refactoring tokenizers - padding - truncation - tests * style and quality * bump up requied tokenizers version to 0.8.0-rc1 * switched padding/truncation API - simpler better backward compat * updating tests for custom tokenizers * style and quality - tests on pad * fix QA pipeline * fix backward compatibility for max_length only * style and quality * Various cleans up - add verbose * fix tests * update docstrings * Fix tests * Docs reformatted * __call__ method documented Co-authored-by: Thomas Wolf <thomwolf@users.noreply.github.com> Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr>
This commit is contained in:
@@ -27,7 +27,7 @@ logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
NON_ENGLISH_TAGS = ["chinese", "dutch", "french", "finnish", "german", "multilingual"]
|
||||
Tokenizer = namedtuple("Tokenizer", ["name", "rust_cls", "python_cls", "vocab_key", "filter"])
|
||||
Tokenizer = namedtuple("Tokenizer", ["name", "rust_cls", "python_cls", "vocab_key", "filter", "kwargs"])
|
||||
|
||||
|
||||
def filter_non_english(_: Tokenizer, pretrained_name: str):
|
||||
@@ -60,10 +60,10 @@ class CommonFastTokenizerTest(unittest.TestCase):
|
||||
tokenizer_r = tok_case.rust_cls.from_pretrained(pretrained_name)
|
||||
tokenizer_p = tok_case.python_cls.from_pretrained(pretrained_name)
|
||||
|
||||
self.fast_align_python(tokenizer_r, tokenizer_p)
|
||||
self.fast_align_python(tokenizer_r, tokenizer_p, tok_case, pretrained_name)
|
||||
self.fast_only(tokenizer_r)
|
||||
|
||||
def fast_align_python(self, tokenizer_r, tokenizer_p):
|
||||
def fast_align_python(self, tokenizer_r, tokenizer_p, tok_case, pretrained_name):
|
||||
# Check is_fast is set correctly
|
||||
self.assertFalse(tokenizer_p.is_fast)
|
||||
self.assertTrue(tokenizer_r.is_fast)
|
||||
@@ -75,6 +75,7 @@ class CommonFastTokenizerTest(unittest.TestCase):
|
||||
self.assert_special_tokens_map_equal(tokenizer_r, tokenizer_p)
|
||||
self.assert_embeded_special_tokens(tokenizer_r, tokenizer_p)
|
||||
self.assert_padding(tokenizer_r, tokenizer_p)
|
||||
self.assert_pretokenized_inputs(tokenizer_r, tokenizer_p)
|
||||
self.assert_create_token_type_ids(tokenizer_r, tokenizer_p)
|
||||
# TODO: enable for v3.0.0
|
||||
# self.assert_empty_output_no_special_tokens(tokenizer_r, tokenizer_p)
|
||||
@@ -90,6 +91,7 @@ class CommonFastTokenizerTest(unittest.TestCase):
|
||||
self.assert_offsets_mapping(tokenizer_r)
|
||||
self.assert_add_special_tokens(tokenizer_r)
|
||||
self.assert_alignement_methods(tokenizer_r)
|
||||
self.assert_batch_encode_dynamic_overflowing(tokenizer_r)
|
||||
|
||||
def assert_alignement_methods(self, tokenizer_r):
|
||||
words = ["Wonderful", "no", "inspiration", "example", "with", "subtoken"]
|
||||
@@ -169,7 +171,7 @@ class CommonFastTokenizerTest(unittest.TestCase):
|
||||
self.assertEqual(batch_encoding.word_to_chars(0, last_word_index).end, last_char_index + 1)
|
||||
self.assertEqual(batch_encoding.word_to_chars(last_batch_index, last_word_index).end, last_char_index + 1)
|
||||
|
||||
def assert_tokenization_python_rust_equals(self, tokenizer_p, tokenizer_r):
|
||||
def assert_tokenization_python_rust_equals(self, tokenizer_r, tokenizer_p):
|
||||
# Ensure basic input match
|
||||
input_p = tokenizer_p.encode_plus(self._data)
|
||||
input_r = tokenizer_r.encode_plus(self._data)
|
||||
@@ -184,18 +186,22 @@ class CommonFastTokenizerTest(unittest.TestCase):
|
||||
self.assertSequenceEqual(input_pairs_p[key], input_pairs_r[key])
|
||||
|
||||
# Ensure truncation match
|
||||
input_p = tokenizer_p.encode_plus(self._data, max_length=512)
|
||||
input_r = tokenizer_r.encode_plus(self._data, max_length=512)
|
||||
input_p = tokenizer_p.encode_plus(self._data, max_length=512, truncation=True)
|
||||
input_r = tokenizer_r.encode_plus(self._data, max_length=512, truncation=True)
|
||||
|
||||
for key in filter(lambda x: x in ["input_ids", "token_type_ids", "attention_mask"], input_p.keys()):
|
||||
self.assertSequenceEqual(input_p[key], input_r[key])
|
||||
|
||||
# Ensure truncation with stride match
|
||||
input_p = tokenizer_p.encode_plus(self._data, max_length=512, stride=3, return_overflowing_tokens=True)
|
||||
input_r = tokenizer_r.encode_plus(self._data, max_length=512, stride=3, return_overflowing_tokens=True)
|
||||
input_p = tokenizer_p.encode_plus(
|
||||
self._data, max_length=512, truncation=True, stride=3, return_overflowing_tokens=True
|
||||
)
|
||||
input_r = tokenizer_r.encode_plus(
|
||||
self._data, max_length=512, truncation=True, stride=3, return_overflowing_tokens=True
|
||||
)
|
||||
|
||||
for key in filter(lambda x: x in ["input_ids", "token_type_ids", "attention_mask"], input_p.keys()):
|
||||
self.assertSequenceEqual(input_p[key], input_r[key])
|
||||
self.assertSequenceEqual(input_p[key], input_r[key][0])
|
||||
|
||||
def assert_num_special_tokens_to_add_equal(self, tokenizer_r, tokenizer_p):
|
||||
# Check we have the same number of added_tokens for both pair and non-pair inputs.
|
||||
@@ -274,9 +280,14 @@ class CommonFastTokenizerTest(unittest.TestCase):
|
||||
"""
|
||||
returned_tensor = "pt" if is_torch_available() else "tf"
|
||||
|
||||
if not tokenizer.pad_token or tokenizer.pad_token_id < 0:
|
||||
return
|
||||
|
||||
tokens = tokenizer.encode_plus(
|
||||
"HuggingFace is solving NLP one commit at a time",
|
||||
max_length=6,
|
||||
padding=True,
|
||||
truncation=True,
|
||||
return_tensors=returned_tensor,
|
||||
return_overflowing_tokens=True,
|
||||
)
|
||||
@@ -288,7 +299,8 @@ class CommonFastTokenizerTest(unittest.TestCase):
|
||||
tokens = tokenizer.batch_encode_plus(
|
||||
["HuggingFace is solving NLP one commit at a time"],
|
||||
max_length=6,
|
||||
pad_to_max_len=True,
|
||||
padding=True,
|
||||
truncation="only_first",
|
||||
return_tensors=returned_tensor,
|
||||
return_overflowing_tokens=True,
|
||||
)
|
||||
@@ -301,7 +313,8 @@ class CommonFastTokenizerTest(unittest.TestCase):
|
||||
tokens = tokenizer.batch_encode_plus(
|
||||
["HuggingFace is solving NLP one commit at a time", "Very tiny input"],
|
||||
max_length=6,
|
||||
pad_to_max_len=True,
|
||||
padding=True,
|
||||
truncation="only_first",
|
||||
return_tensors=returned_tensor,
|
||||
return_overflowing_tokens=True,
|
||||
)
|
||||
@@ -310,6 +323,58 @@ class CommonFastTokenizerTest(unittest.TestCase):
|
||||
self.assertEqual(len(tokens[key].shape), 2)
|
||||
self.assertEqual(tokens[key].shape[-1], 6)
|
||||
|
||||
def assert_pretokenized_inputs(self, tokenizer_r, tokenizer_p):
|
||||
# Input string
|
||||
pretokenized_input_simple = "This is a sample input".split()
|
||||
pretokenized_input_pair = "This is a sample pair".split()
|
||||
|
||||
# Test encode for pretokenized inputs
|
||||
output_r = tokenizer_r.encode(pretokenized_input_simple, is_pretokenized=True)
|
||||
output_p = tokenizer_p.encode(pretokenized_input_simple, is_pretokenized=True)
|
||||
self.assertEqual(output_p, output_r)
|
||||
|
||||
kwargs = {
|
||||
"is_pretokenized": True,
|
||||
"return_token_type_ids": True,
|
||||
"return_attention_mask": True,
|
||||
"return_overflowing_tokens": False,
|
||||
"return_special_tokens_mask": True,
|
||||
"return_offsets_mapping": False, # Not implemented in python tokenizers
|
||||
}
|
||||
# Test encode_plus for pretokenized inputs
|
||||
output_r = tokenizer_r.encode_plus(pretokenized_input_simple, **kwargs)
|
||||
output_p = tokenizer_p.encode_plus(pretokenized_input_simple, **kwargs)
|
||||
for key in output_p.keys():
|
||||
self.assertEqual(output_p[key], output_r[key])
|
||||
|
||||
# Test batch_encode_plus for pretokenized inputs
|
||||
input_batch = ([pretokenized_input_simple] * 2) + [pretokenized_input_simple + pretokenized_input_pair]
|
||||
output_r = tokenizer_r.batch_encode_plus(input_batch, **kwargs)
|
||||
output_p = tokenizer_p.batch_encode_plus(input_batch, **kwargs)
|
||||
for key in output_p.keys():
|
||||
self.assertEqual(output_p[key], output_r[key])
|
||||
|
||||
# Test encode for pretokenized inputs pairs
|
||||
output_r = tokenizer_r.encode(pretokenized_input_simple, pretokenized_input_pair, is_pretokenized=True)
|
||||
output_p = tokenizer_p.encode(pretokenized_input_simple, pretokenized_input_pair, is_pretokenized=True)
|
||||
self.assertEqual(output_p, output_r)
|
||||
|
||||
# Test encode_plus for pretokenized inputs
|
||||
output_r = tokenizer_r.encode_plus(pretokenized_input_simple, pretokenized_input_pair, **kwargs)
|
||||
output_p = tokenizer_p.encode_plus(pretokenized_input_simple, pretokenized_input_pair, **kwargs)
|
||||
for key in output_p.keys():
|
||||
self.assertEqual(output_p[key], output_r[key])
|
||||
|
||||
# Test batch_encode_plus for pretokenized inputs
|
||||
input_batch_pair = ([pretokenized_input_simple, pretokenized_input_pair] * 2) + [
|
||||
pretokenized_input_simple + pretokenized_input_pair,
|
||||
pretokenized_input_pair,
|
||||
]
|
||||
output_r = tokenizer_r.batch_encode_plus(input_batch_pair, **kwargs)
|
||||
output_p = tokenizer_p.batch_encode_plus(input_batch_pair, **kwargs)
|
||||
for key in output_p.keys():
|
||||
self.assertEqual(output_p[key], output_r[key])
|
||||
|
||||
def assert_create_token_type_ids(self, tokenizer_r, tokenizer_p):
|
||||
input_simple = [1, 2, 3]
|
||||
input_pair = [1, 2, 3]
|
||||
@@ -357,17 +422,22 @@ class CommonFastTokenizerTest(unittest.TestCase):
|
||||
def assert_padded_input_match(input_r: list, input_p: list, max_length: int):
|
||||
|
||||
# Ensure we match max_length
|
||||
self.assertEqual(len(input_r), max_length), self.assertEqual(len(input_p), max_length)
|
||||
self.assertEqual(len(input_r), max_length)
|
||||
self.assertEqual(len(input_p), max_length)
|
||||
|
||||
# Ensure the number of padded tokens is the same
|
||||
padded_tokens_r = list(takewhile(lambda i: i == tokenizer_r.pad_token_id, reversed(input_r)))
|
||||
padded_tokens_p = list(takewhile(lambda i: i == tokenizer_p.pad_token_id, reversed(input_p)))
|
||||
self.assertSequenceEqual(padded_tokens_r, padded_tokens_p)
|
||||
|
||||
def assert_batch_padded_input_match(input_r: dict, input_p: dict):
|
||||
def assert_batch_padded_input_match(input_r: dict, input_p: dict, max_length: int):
|
||||
for i_r in input_r.values():
|
||||
self.assertEqual(len(i_r), 2), self.assertEqual(len(i_r[0]), 15), self.assertEqual(len(i_r[1]), 15)
|
||||
self.assertEqual(len(i_r), 2), self.assertEqual(len(i_r[0]), 15), self.assertEqual(len(i_r[1]), 15)
|
||||
self.assertEqual(len(i_r), 2), self.assertEqual(len(i_r[0]), max_length), self.assertEqual(
|
||||
len(i_r[1]), max_length
|
||||
)
|
||||
self.assertEqual(len(i_r), 2), self.assertEqual(len(i_r[0]), max_length), self.assertEqual(
|
||||
len(i_r[1]), max_length
|
||||
)
|
||||
|
||||
for i_r, i_p in zip(input_r["input_ids"], input_p["input_ids"]):
|
||||
assert_padded_input_match(i_r, i_p, max_length)
|
||||
@@ -375,12 +445,19 @@ class CommonFastTokenizerTest(unittest.TestCase):
|
||||
for i_r, i_p in zip(input_r["attention_mask"], input_p["attention_mask"]):
|
||||
self.assertSequenceEqual(i_r, i_p)
|
||||
|
||||
# Simple input
|
||||
# Encode - Simple input
|
||||
input_r = tokenizer_r.encode("This is a simple input", max_length=max_length, pad_to_max_length=True)
|
||||
input_p = tokenizer_p.encode("This is a simple input", max_length=max_length, pad_to_max_length=True)
|
||||
assert_padded_input_match(input_r, input_p, max_length)
|
||||
input_r = tokenizer_r.encode("This is a simple input", max_length=max_length, padding="max_length")
|
||||
input_p = tokenizer_p.encode("This is a simple input", max_length=max_length, padding="max_length")
|
||||
assert_padded_input_match(input_r, input_p, max_length)
|
||||
|
||||
# Pair input
|
||||
input_r = tokenizer_r.encode("This is a simple input", padding="longest")
|
||||
input_p = tokenizer_p.encode("This is a simple input", padding=True)
|
||||
assert_padded_input_match(input_r, input_p, len(input_r))
|
||||
|
||||
# Encode - Pair input
|
||||
input_r = tokenizer_r.encode(
|
||||
"This is a simple input", "This is a pair", max_length=max_length, pad_to_max_length=True
|
||||
)
|
||||
@@ -388,14 +465,34 @@ class CommonFastTokenizerTest(unittest.TestCase):
|
||||
"This is a simple input", "This is a pair", max_length=max_length, pad_to_max_length=True
|
||||
)
|
||||
assert_padded_input_match(input_r, input_p, max_length)
|
||||
input_r = tokenizer_r.encode(
|
||||
"This is a simple input", "This is a pair", max_length=max_length, padding="max_length"
|
||||
)
|
||||
input_p = tokenizer_p.encode(
|
||||
"This is a simple input", "This is a pair", max_length=max_length, padding="max_length"
|
||||
)
|
||||
assert_padded_input_match(input_r, input_p, max_length)
|
||||
input_r = tokenizer_r.encode("This is a simple input", "This is a pair", padding=True)
|
||||
input_p = tokenizer_p.encode("This is a simple input", "This is a pair", padding="longest")
|
||||
assert_padded_input_match(input_r, input_p, len(input_r))
|
||||
|
||||
# Simple input
|
||||
# Encode_plus - Simple input
|
||||
input_r = tokenizer_r.encode_plus("This is a simple input", max_length=max_length, pad_to_max_length=True)
|
||||
input_p = tokenizer_p.encode_plus("This is a simple input", max_length=max_length, pad_to_max_length=True)
|
||||
assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], max_length)
|
||||
self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"])
|
||||
input_r = tokenizer_r.encode_plus("This is a simple input", max_length=max_length, padding="max_length")
|
||||
input_p = tokenizer_p.encode_plus("This is a simple input", max_length=max_length, padding="max_length")
|
||||
assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], max_length)
|
||||
self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"])
|
||||
|
||||
# Pair input
|
||||
input_r = tokenizer_r.encode_plus("This is a simple input", padding="longest")
|
||||
input_p = tokenizer_p.encode_plus("This is a simple input", padding=True)
|
||||
assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], len(input_r["input_ids"]))
|
||||
|
||||
self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"])
|
||||
|
||||
# Encode_plus - Pair input
|
||||
input_r = tokenizer_r.encode_plus(
|
||||
"This is a simple input", "This is a pair", max_length=max_length, pad_to_max_length=True
|
||||
)
|
||||
@@ -404,34 +501,130 @@ class CommonFastTokenizerTest(unittest.TestCase):
|
||||
)
|
||||
assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], max_length)
|
||||
self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"])
|
||||
input_r = tokenizer_r.encode_plus(
|
||||
"This is a simple input", "This is a pair", max_length=max_length, padding="max_length"
|
||||
)
|
||||
input_p = tokenizer_p.encode_plus(
|
||||
"This is a simple input", "This is a pair", max_length=max_length, padding="max_length"
|
||||
)
|
||||
assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], max_length)
|
||||
self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"])
|
||||
input_r = tokenizer_r.encode_plus("This is a simple input", "This is a pair", padding="longest")
|
||||
input_p = tokenizer_p.encode_plus("This is a simple input", "This is a pair", padding=True)
|
||||
assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], len(input_r["input_ids"]))
|
||||
self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"])
|
||||
|
||||
# Simple input
|
||||
# Batch_encode_plus - Simple input
|
||||
input_r = tokenizer_r.batch_encode_plus(
|
||||
["This is a simple input 1", "This is a simple input 2"], max_length=max_length, pad_to_max_length=True
|
||||
)
|
||||
input_p = tokenizer_p.batch_encode_plus(
|
||||
["This is a simple input 1", "This is a simple input 2"], max_length=max_length, pad_to_max_length=True
|
||||
)
|
||||
assert_batch_padded_input_match(input_r, input_p)
|
||||
assert_batch_padded_input_match(input_r, input_p, max_length)
|
||||
|
||||
# Pair input
|
||||
input_r = tokenizer_r.batch_encode_plus(
|
||||
["This is a simple input 1", "This is a simple input 2"], max_length=max_length, padding="max_length",
|
||||
)
|
||||
input_p = tokenizer_p.batch_encode_plus(
|
||||
["This is a simple input 1", "This is a simple input 2"], max_length=max_length, padding="max_length",
|
||||
)
|
||||
assert_batch_padded_input_match(input_r, input_p, max_length)
|
||||
|
||||
input_r = tokenizer_r.batch_encode_plus(
|
||||
["This is a simple input 1", "This is a simple input 2"], max_length=max_length, padding="longest",
|
||||
)
|
||||
input_p = tokenizer_p.batch_encode_plus(
|
||||
["This is a simple input 1", "This is a simple input 2"], max_length=max_length, padding=True,
|
||||
)
|
||||
assert_batch_padded_input_match(input_r, input_p, len(input_r["input_ids"][0]))
|
||||
|
||||
input_r = tokenizer_r.batch_encode_plus(
|
||||
["This is a simple input 1", "This is a simple input 2"], padding="longest"
|
||||
)
|
||||
input_p = tokenizer_p.batch_encode_plus(["This is a simple input 1", "This is a simple input 2"], padding=True)
|
||||
assert_batch_padded_input_match(input_r, input_p, len(input_r["input_ids"][0]))
|
||||
|
||||
# Batch_encode_plus - Pair input
|
||||
input_r = tokenizer_r.batch_encode_plus(
|
||||
[
|
||||
("This is a simple input 1", "This is a simple input 2"),
|
||||
("This is a simple pair 1", "This is a simple pair 2"),
|
||||
],
|
||||
max_length=15,
|
||||
pad_to_max_length=True,
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
padding="max_length",
|
||||
)
|
||||
input_p = tokenizer_p.batch_encode_plus(
|
||||
[
|
||||
("This is a simple input 1", "This is a simple input 2"),
|
||||
("This is a simple pair 1", "This is a simple pair 2"),
|
||||
],
|
||||
max_length=15,
|
||||
pad_to_max_length=True,
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
padding="max_length",
|
||||
)
|
||||
assert_batch_padded_input_match(input_r, input_p)
|
||||
assert_batch_padded_input_match(input_r, input_p, max_length)
|
||||
|
||||
input_r = tokenizer_r.batch_encode_plus(
|
||||
[
|
||||
("This is a simple input 1", "This is a simple input 2"),
|
||||
("This is a simple pair 1", "This is a simple pair 2"),
|
||||
],
|
||||
padding=True,
|
||||
)
|
||||
input_p = tokenizer_p.batch_encode_plus(
|
||||
[
|
||||
("This is a simple input 1", "This is a simple input 2"),
|
||||
("This is a simple pair 1", "This is a simple pair 2"),
|
||||
],
|
||||
padding="longest",
|
||||
)
|
||||
assert_batch_padded_input_match(input_r, input_p, len(input_r["input_ids"][0]))
|
||||
|
||||
# Using pad on single examples after tokenization
|
||||
input_r = tokenizer_r.encode_plus("This is a input 1")
|
||||
input_r = tokenizer_r.pad(input_r)
|
||||
|
||||
input_p = tokenizer_r.encode_plus("This is a input 1")
|
||||
input_p = tokenizer_r.pad(input_p)
|
||||
|
||||
assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], len(input_r["input_ids"]))
|
||||
|
||||
# Using pad on single examples after tokenization
|
||||
input_r = tokenizer_r.encode_plus("This is a input 1")
|
||||
input_r = tokenizer_r.pad(input_r, max_length=max_length, padding="max_length")
|
||||
|
||||
input_p = tokenizer_r.encode_plus("This is a input 1")
|
||||
input_p = tokenizer_r.pad(input_p, max_length=max_length, padding="max_length")
|
||||
|
||||
assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], max_length)
|
||||
|
||||
# Using pad after tokenization
|
||||
input_r = tokenizer_r.batch_encode_plus(
|
||||
["This is a input 1", "This is a much longer input whilch should be padded"]
|
||||
)
|
||||
input_r = tokenizer_r.pad(input_r)
|
||||
|
||||
input_p = tokenizer_r.batch_encode_plus(
|
||||
["This is a input 1", "This is a much longer input whilch should be padded"]
|
||||
)
|
||||
input_p = tokenizer_r.pad(input_p)
|
||||
|
||||
assert_batch_padded_input_match(input_r, input_p, len(input_r["input_ids"][0]))
|
||||
|
||||
# Using pad after tokenization
|
||||
input_r = tokenizer_r.batch_encode_plus(
|
||||
["This is a input 1", "This is a much longer input whilch should be padded"]
|
||||
)
|
||||
input_r = tokenizer_r.pad(input_r, max_length=max_length, padding="max_length")
|
||||
|
||||
input_p = tokenizer_r.batch_encode_plus(
|
||||
["This is a input 1", "This is a much longer input whilch should be padded"]
|
||||
)
|
||||
input_p = tokenizer_r.pad(input_p, max_length=max_length, padding="max_length")
|
||||
|
||||
assert_batch_padded_input_match(input_r, input_p, max_length)
|
||||
|
||||
def assert_save_pretrained(self, tokenizer_r, tokenizer_p):
|
||||
# Checks it save with the same files
|
||||
@@ -503,8 +696,10 @@ class WordPieceFastTokenizerTest(CommonFastTokenizerTest):
|
||||
|
||||
TOKENIZERS_CLASSES = frozenset(
|
||||
[
|
||||
Tokenizer("Bert", BertTokenizerFast, BertTokenizer, "vocab_file", filter_non_english),
|
||||
Tokenizer("DistilBert", DistilBertTokenizerFast, DistilBertTokenizer, "vocab_file", filter_non_english),
|
||||
Tokenizer("Bert", BertTokenizerFast, BertTokenizer, "vocab_file", filter_non_english, None),
|
||||
Tokenizer(
|
||||
"DistilBert", DistilBertTokenizerFast, DistilBertTokenizer, "vocab_file", filter_non_english, None
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
@@ -552,7 +747,7 @@ class WordPieceFastTokenizerTest(CommonFastTokenizerTest):
|
||||
|
||||
class RobertaFastTokenizerTest(CommonFastTokenizerTest):
|
||||
TOKENIZERS_CLASSES = frozenset(
|
||||
[Tokenizer("Roberta", RobertaTokenizerFast, RobertaTokenizer, "vocab_file", filter_roberta_detectors)]
|
||||
[Tokenizer("Roberta", RobertaTokenizerFast, RobertaTokenizer, "vocab_file", filter_roberta_detectors, None)]
|
||||
)
|
||||
|
||||
def assert_embeded_special_tokens(self, tokenizer_r, tokenizer_p):
|
||||
@@ -580,10 +775,30 @@ class RobertaFastTokenizerTest(CommonFastTokenizerTest):
|
||||
|
||||
class NoPaddingTokenFastTokenizerMatchingTest(CommonFastTokenizerTest):
|
||||
TOKENIZERS_CLASSES = [
|
||||
Tokenizer("OpenAI GPT", OpenAIGPTTokenizerFast, OpenAIGPTTokenizer, "vocab_file", None),
|
||||
Tokenizer("GPT2", GPT2TokenizerFast, GPT2Tokenizer, "vocab_file", None),
|
||||
Tokenizer("OpenAI GPT", OpenAIGPTTokenizerFast, OpenAIGPTTokenizer, "vocab_file", None, None),
|
||||
Tokenizer("GPT2", GPT2TokenizerFast, GPT2Tokenizer, "vocab_file", None, [("add_prefix_space", True)]),
|
||||
]
|
||||
|
||||
def fast_align_python(self, tokenizer_r, tokenizer_p, tok_case, pretrained_name):
|
||||
# Check is_fast is set correctly
|
||||
self.assertFalse(tokenizer_p.is_fast)
|
||||
self.assertTrue(tokenizer_r.is_fast)
|
||||
|
||||
# Check that Rust and Python align
|
||||
self.assert_tokenization_python_rust_equals(tokenizer_r, tokenizer_p)
|
||||
self.assert_num_special_tokens_to_add_equal(tokenizer_r, tokenizer_p)
|
||||
self.assert_max_length_equal(tokenizer_r, tokenizer_p)
|
||||
self.assert_special_tokens_map_equal(tokenizer_r, tokenizer_p)
|
||||
self.assert_embeded_special_tokens(tokenizer_r, tokenizer_p)
|
||||
self.assert_padding(tokenizer_r, tokenizer_p)
|
||||
|
||||
# Specific for
|
||||
kwargs = {}
|
||||
if tok_case.kwargs is not None:
|
||||
kwargs = dict(tok_case.kwargs)
|
||||
tokenizer_r = tok_case.rust_cls.from_pretrained(pretrained_name, **kwargs)
|
||||
self.assert_pretokenized_inputs(tokenizer_r, tokenizer_p)
|
||||
|
||||
def assert_padding(self, tokenizer_r, tokenizer_p, max_length=15):
|
||||
# Simple input
|
||||
s = "This is a simple input"
|
||||
@@ -595,27 +810,31 @@ class NoPaddingTokenFastTokenizerMatchingTest(CommonFastTokenizerTest):
|
||||
]
|
||||
|
||||
# Simple input tests
|
||||
self.assertRaises(ValueError, tokenizer_r.encode, s, max_length=max_length, pad_to_max_length=True)
|
||||
self.assertRaises(ValueError, tokenizer_r.encode, s, max_length=max_length, padding="max_length")
|
||||
|
||||
# Simple input
|
||||
self.assertRaises(ValueError, tokenizer_r.encode_plus, s, max_length=max_length, pad_to_max_length=True)
|
||||
self.assertRaises(ValueError, tokenizer_r.encode_plus, s, max_length=max_length, padding="max_length")
|
||||
|
||||
# Simple input
|
||||
self.assertRaises(ValueError, tokenizer_r.batch_encode_plus, s2, max_length=max_length, pad_to_max_length=True)
|
||||
self.assertRaises(
|
||||
ValueError, tokenizer_r.batch_encode_plus, s2, max_length=max_length, padding="max_length",
|
||||
)
|
||||
|
||||
# Pair input
|
||||
self.assertRaises(ValueError, tokenizer_r.encode, p, max_length=max_length, pad_to_max_length=True)
|
||||
self.assertRaises(ValueError, tokenizer_r.encode, p, max_length=max_length, padding="max_length")
|
||||
|
||||
# Pair input
|
||||
self.assertRaises(ValueError, tokenizer_r.encode_plus, p, max_length=max_length, pad_to_max_length=True)
|
||||
self.assertRaises(ValueError, tokenizer_r.encode_plus, p, max_length=max_length, padding="max_length")
|
||||
|
||||
# Pair input
|
||||
self.assertRaises(ValueError, tokenizer_r.batch_encode_plus, p2, max_length=max_length, pad_to_max_length=True)
|
||||
self.assertRaises(
|
||||
ValueError, tokenizer_r.batch_encode_plus, p2, max_length=max_length, padding="max_length",
|
||||
)
|
||||
|
||||
|
||||
class TransfoXLFastTokenizerTest(NoPaddingTokenFastTokenizerMatchingTest):
|
||||
TOKENIZERS_CLASSES = frozenset(
|
||||
[Tokenizer("TransfoXL", TransfoXLTokenizerFast, TransfoXLTokenizer, "pretrained_vocab_file", None)]
|
||||
[Tokenizer("TransfoXL", TransfoXLTokenizerFast, TransfoXLTokenizer, "pretrained_vocab_file", None, None)]
|
||||
)
|
||||
|
||||
@require_torch
|
||||
|
||||
Reference in New Issue
Block a user