[HUGE] Refactoring tokenizers backend - padding - truncation - pre-tokenized pipeline - fast tokenizers - tests (#4510)
* Use tokenizers pre-tokenized pipeline * failing pretrokenized test * Fix is_pretokenized in python * add pretokenized tests * style and quality * better tests for batched pretokenized inputs * tokenizers clean up - new padding_strategy - split the files * [HUGE] refactoring tokenizers - padding - truncation - tests * style and quality * bump up requied tokenizers version to 0.8.0-rc1 * switched padding/truncation API - simpler better backward compat * updating tests for custom tokenizers * style and quality - tests on pad * fix QA pipeline * fix backward compatibility for max_length only * style and quality * Various cleans up - add verbose * fix tests * update docstrings * Fix tests * Docs reformatted * __call__ method documented Co-authored-by: Thomas Wolf <thomwolf@users.noreply.github.com> Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr>
This commit is contained in:
@@ -36,7 +36,7 @@ class AlbertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
tokenizer = AlbertTokenizer(SAMPLE_VOCAB)
|
||||
tokenizer.save_pretrained(self.tmpdirname)
|
||||
|
||||
def get_input_output_texts(self):
|
||||
def get_input_output_texts(self, tokenizer):
|
||||
input_text = "this is a test"
|
||||
output_text = "this is a test"
|
||||
return input_text, output_text
|
||||
|
||||
@@ -44,6 +44,8 @@ class BertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
"[UNK]",
|
||||
"[CLS]",
|
||||
"[SEP]",
|
||||
"[PAD]",
|
||||
"[MASK]",
|
||||
"want",
|
||||
"##want",
|
||||
"##ed",
|
||||
@@ -62,7 +64,7 @@ class BertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
def get_rust_tokenizer(self, **kwargs):
|
||||
return BertTokenizerFast.from_pretrained(self.tmpdirname, **kwargs)
|
||||
|
||||
def get_input_output_texts(self):
|
||||
def get_input_output_texts(self, tokenizer):
|
||||
input_text = "UNwant\u00E9d,running"
|
||||
output_text = "unwanted, running"
|
||||
return input_text, output_text
|
||||
@@ -72,7 +74,7 @@ class BertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
|
||||
tokens = tokenizer.tokenize("UNwant\u00E9d,running")
|
||||
self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
|
||||
self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])
|
||||
self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [9, 6, 7, 12, 10, 11])
|
||||
|
||||
def test_rust_and_python_full_tokenizers(self):
|
||||
if not self.test_rust_tokenizer:
|
||||
@@ -96,6 +98,25 @@ class BertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
rust_ids = rust_tokenizer.encode(sequence)
|
||||
self.assertListEqual(ids, rust_ids)
|
||||
|
||||
# With lower casing
|
||||
tokenizer = self.get_tokenizer(do_lower_case=True)
|
||||
rust_tokenizer = self.get_rust_tokenizer(do_lower_case=True)
|
||||
|
||||
sequence = "UNwant\u00E9d,running"
|
||||
|
||||
tokens = tokenizer.tokenize(sequence)
|
||||
rust_tokens = rust_tokenizer.tokenize(sequence)
|
||||
self.assertListEqual(tokens, rust_tokens)
|
||||
|
||||
ids = tokenizer.encode(sequence, add_special_tokens=False)
|
||||
rust_ids = rust_tokenizer.encode(sequence, add_special_tokens=False)
|
||||
self.assertListEqual(ids, rust_ids)
|
||||
|
||||
rust_tokenizer = self.get_rust_tokenizer()
|
||||
ids = tokenizer.encode(sequence)
|
||||
rust_ids = rust_tokenizer.encode(sequence)
|
||||
self.assertListEqual(ids, rust_ids)
|
||||
|
||||
def test_chinese(self):
|
||||
tokenizer = BasicTokenizer()
|
||||
|
||||
|
||||
@@ -60,11 +60,26 @@ class BertJapaneseTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
|
||||
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
|
||||
|
||||
def get_input_output_texts(self):
|
||||
def get_input_output_texts(self, tokenizer):
|
||||
input_text = "こんにちは、世界。 \nこんばんは、世界。"
|
||||
output_text = "こんにちは 、 世界 。 こんばんは 、 世界 。"
|
||||
return input_text, output_text
|
||||
|
||||
def get_clean_sequence(self, tokenizer):
|
||||
input_text, output_text = self.get_input_output_texts(tokenizer)
|
||||
ids = tokenizer.encode(output_text, add_special_tokens=False)
|
||||
text = tokenizer.decode(ids, clean_up_tokenization_spaces=False)
|
||||
return text, ids
|
||||
|
||||
def test_pretokenized_inputs(self):
|
||||
pass # TODO add if relevant
|
||||
|
||||
def test_maximum_encoding_length_pair_input(self):
|
||||
pass # TODO add if relevant
|
||||
|
||||
def test_maximum_encoding_length_single_input(self):
|
||||
pass # TODO add if relevant
|
||||
|
||||
def test_full_tokenizer(self):
|
||||
tokenizer = self.tokenizer_class(self.vocab_file)
|
||||
|
||||
@@ -157,11 +172,20 @@ class BertJapaneseCharacterTokenizationTest(TokenizerTesterMixin, unittest.TestC
|
||||
def get_tokenizer(self, **kwargs):
|
||||
return BertJapaneseTokenizer.from_pretrained(self.tmpdirname, subword_tokenizer_type="character", **kwargs)
|
||||
|
||||
def get_input_output_texts(self):
|
||||
def get_input_output_texts(self, tokenizer):
|
||||
input_text = "こんにちは、世界。 \nこんばんは、世界。"
|
||||
output_text = "こ ん に ち は 、 世 界 。 こ ん ば ん は 、 世 界 。"
|
||||
return input_text, output_text
|
||||
|
||||
def test_pretokenized_inputs(self):
|
||||
pass # TODO add if relevant
|
||||
|
||||
def test_maximum_encoding_length_pair_input(self):
|
||||
pass # TODO add if relevant
|
||||
|
||||
def test_maximum_encoding_length_single_input(self):
|
||||
pass # TODO add if relevant
|
||||
|
||||
def test_full_tokenizer(self):
|
||||
tokenizer = self.tokenizer_class(self.vocab_file, subword_tokenizer_type="character")
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -46,7 +46,7 @@ class CTRLTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
kwargs.update(self.special_tokens_map)
|
||||
return CTRLTokenizer.from_pretrained(self.tmpdirname, **kwargs)
|
||||
|
||||
def get_input_output_texts(self):
|
||||
def get_input_output_texts(self, tokenizer):
|
||||
input_text = "adapt react readapt apt"
|
||||
output_text = "adapt react readapt apt"
|
||||
return input_text, output_text
|
||||
|
||||
@@ -27,7 +27,7 @@ logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
NON_ENGLISH_TAGS = ["chinese", "dutch", "french", "finnish", "german", "multilingual"]
|
||||
Tokenizer = namedtuple("Tokenizer", ["name", "rust_cls", "python_cls", "vocab_key", "filter"])
|
||||
Tokenizer = namedtuple("Tokenizer", ["name", "rust_cls", "python_cls", "vocab_key", "filter", "kwargs"])
|
||||
|
||||
|
||||
def filter_non_english(_: Tokenizer, pretrained_name: str):
|
||||
@@ -60,10 +60,10 @@ class CommonFastTokenizerTest(unittest.TestCase):
|
||||
tokenizer_r = tok_case.rust_cls.from_pretrained(pretrained_name)
|
||||
tokenizer_p = tok_case.python_cls.from_pretrained(pretrained_name)
|
||||
|
||||
self.fast_align_python(tokenizer_r, tokenizer_p)
|
||||
self.fast_align_python(tokenizer_r, tokenizer_p, tok_case, pretrained_name)
|
||||
self.fast_only(tokenizer_r)
|
||||
|
||||
def fast_align_python(self, tokenizer_r, tokenizer_p):
|
||||
def fast_align_python(self, tokenizer_r, tokenizer_p, tok_case, pretrained_name):
|
||||
# Check is_fast is set correctly
|
||||
self.assertFalse(tokenizer_p.is_fast)
|
||||
self.assertTrue(tokenizer_r.is_fast)
|
||||
@@ -75,6 +75,7 @@ class CommonFastTokenizerTest(unittest.TestCase):
|
||||
self.assert_special_tokens_map_equal(tokenizer_r, tokenizer_p)
|
||||
self.assert_embeded_special_tokens(tokenizer_r, tokenizer_p)
|
||||
self.assert_padding(tokenizer_r, tokenizer_p)
|
||||
self.assert_pretokenized_inputs(tokenizer_r, tokenizer_p)
|
||||
self.assert_create_token_type_ids(tokenizer_r, tokenizer_p)
|
||||
# TODO: enable for v3.0.0
|
||||
# self.assert_empty_output_no_special_tokens(tokenizer_r, tokenizer_p)
|
||||
@@ -90,6 +91,7 @@ class CommonFastTokenizerTest(unittest.TestCase):
|
||||
self.assert_offsets_mapping(tokenizer_r)
|
||||
self.assert_add_special_tokens(tokenizer_r)
|
||||
self.assert_alignement_methods(tokenizer_r)
|
||||
self.assert_batch_encode_dynamic_overflowing(tokenizer_r)
|
||||
|
||||
def assert_alignement_methods(self, tokenizer_r):
|
||||
words = ["Wonderful", "no", "inspiration", "example", "with", "subtoken"]
|
||||
@@ -169,7 +171,7 @@ class CommonFastTokenizerTest(unittest.TestCase):
|
||||
self.assertEqual(batch_encoding.word_to_chars(0, last_word_index).end, last_char_index + 1)
|
||||
self.assertEqual(batch_encoding.word_to_chars(last_batch_index, last_word_index).end, last_char_index + 1)
|
||||
|
||||
def assert_tokenization_python_rust_equals(self, tokenizer_p, tokenizer_r):
|
||||
def assert_tokenization_python_rust_equals(self, tokenizer_r, tokenizer_p):
|
||||
# Ensure basic input match
|
||||
input_p = tokenizer_p.encode_plus(self._data)
|
||||
input_r = tokenizer_r.encode_plus(self._data)
|
||||
@@ -184,18 +186,22 @@ class CommonFastTokenizerTest(unittest.TestCase):
|
||||
self.assertSequenceEqual(input_pairs_p[key], input_pairs_r[key])
|
||||
|
||||
# Ensure truncation match
|
||||
input_p = tokenizer_p.encode_plus(self._data, max_length=512)
|
||||
input_r = tokenizer_r.encode_plus(self._data, max_length=512)
|
||||
input_p = tokenizer_p.encode_plus(self._data, max_length=512, truncation=True)
|
||||
input_r = tokenizer_r.encode_plus(self._data, max_length=512, truncation=True)
|
||||
|
||||
for key in filter(lambda x: x in ["input_ids", "token_type_ids", "attention_mask"], input_p.keys()):
|
||||
self.assertSequenceEqual(input_p[key], input_r[key])
|
||||
|
||||
# Ensure truncation with stride match
|
||||
input_p = tokenizer_p.encode_plus(self._data, max_length=512, stride=3, return_overflowing_tokens=True)
|
||||
input_r = tokenizer_r.encode_plus(self._data, max_length=512, stride=3, return_overflowing_tokens=True)
|
||||
input_p = tokenizer_p.encode_plus(
|
||||
self._data, max_length=512, truncation=True, stride=3, return_overflowing_tokens=True
|
||||
)
|
||||
input_r = tokenizer_r.encode_plus(
|
||||
self._data, max_length=512, truncation=True, stride=3, return_overflowing_tokens=True
|
||||
)
|
||||
|
||||
for key in filter(lambda x: x in ["input_ids", "token_type_ids", "attention_mask"], input_p.keys()):
|
||||
self.assertSequenceEqual(input_p[key], input_r[key])
|
||||
self.assertSequenceEqual(input_p[key], input_r[key][0])
|
||||
|
||||
def assert_num_special_tokens_to_add_equal(self, tokenizer_r, tokenizer_p):
|
||||
# Check we have the same number of added_tokens for both pair and non-pair inputs.
|
||||
@@ -274,9 +280,14 @@ class CommonFastTokenizerTest(unittest.TestCase):
|
||||
"""
|
||||
returned_tensor = "pt" if is_torch_available() else "tf"
|
||||
|
||||
if not tokenizer.pad_token or tokenizer.pad_token_id < 0:
|
||||
return
|
||||
|
||||
tokens = tokenizer.encode_plus(
|
||||
"HuggingFace is solving NLP one commit at a time",
|
||||
max_length=6,
|
||||
padding=True,
|
||||
truncation=True,
|
||||
return_tensors=returned_tensor,
|
||||
return_overflowing_tokens=True,
|
||||
)
|
||||
@@ -288,7 +299,8 @@ class CommonFastTokenizerTest(unittest.TestCase):
|
||||
tokens = tokenizer.batch_encode_plus(
|
||||
["HuggingFace is solving NLP one commit at a time"],
|
||||
max_length=6,
|
||||
pad_to_max_len=True,
|
||||
padding=True,
|
||||
truncation="only_first",
|
||||
return_tensors=returned_tensor,
|
||||
return_overflowing_tokens=True,
|
||||
)
|
||||
@@ -301,7 +313,8 @@ class CommonFastTokenizerTest(unittest.TestCase):
|
||||
tokens = tokenizer.batch_encode_plus(
|
||||
["HuggingFace is solving NLP one commit at a time", "Very tiny input"],
|
||||
max_length=6,
|
||||
pad_to_max_len=True,
|
||||
padding=True,
|
||||
truncation="only_first",
|
||||
return_tensors=returned_tensor,
|
||||
return_overflowing_tokens=True,
|
||||
)
|
||||
@@ -310,6 +323,58 @@ class CommonFastTokenizerTest(unittest.TestCase):
|
||||
self.assertEqual(len(tokens[key].shape), 2)
|
||||
self.assertEqual(tokens[key].shape[-1], 6)
|
||||
|
||||
def assert_pretokenized_inputs(self, tokenizer_r, tokenizer_p):
|
||||
# Input string
|
||||
pretokenized_input_simple = "This is a sample input".split()
|
||||
pretokenized_input_pair = "This is a sample pair".split()
|
||||
|
||||
# Test encode for pretokenized inputs
|
||||
output_r = tokenizer_r.encode(pretokenized_input_simple, is_pretokenized=True)
|
||||
output_p = tokenizer_p.encode(pretokenized_input_simple, is_pretokenized=True)
|
||||
self.assertEqual(output_p, output_r)
|
||||
|
||||
kwargs = {
|
||||
"is_pretokenized": True,
|
||||
"return_token_type_ids": True,
|
||||
"return_attention_mask": True,
|
||||
"return_overflowing_tokens": False,
|
||||
"return_special_tokens_mask": True,
|
||||
"return_offsets_mapping": False, # Not implemented in python tokenizers
|
||||
}
|
||||
# Test encode_plus for pretokenized inputs
|
||||
output_r = tokenizer_r.encode_plus(pretokenized_input_simple, **kwargs)
|
||||
output_p = tokenizer_p.encode_plus(pretokenized_input_simple, **kwargs)
|
||||
for key in output_p.keys():
|
||||
self.assertEqual(output_p[key], output_r[key])
|
||||
|
||||
# Test batch_encode_plus for pretokenized inputs
|
||||
input_batch = ([pretokenized_input_simple] * 2) + [pretokenized_input_simple + pretokenized_input_pair]
|
||||
output_r = tokenizer_r.batch_encode_plus(input_batch, **kwargs)
|
||||
output_p = tokenizer_p.batch_encode_plus(input_batch, **kwargs)
|
||||
for key in output_p.keys():
|
||||
self.assertEqual(output_p[key], output_r[key])
|
||||
|
||||
# Test encode for pretokenized inputs pairs
|
||||
output_r = tokenizer_r.encode(pretokenized_input_simple, pretokenized_input_pair, is_pretokenized=True)
|
||||
output_p = tokenizer_p.encode(pretokenized_input_simple, pretokenized_input_pair, is_pretokenized=True)
|
||||
self.assertEqual(output_p, output_r)
|
||||
|
||||
# Test encode_plus for pretokenized inputs
|
||||
output_r = tokenizer_r.encode_plus(pretokenized_input_simple, pretokenized_input_pair, **kwargs)
|
||||
output_p = tokenizer_p.encode_plus(pretokenized_input_simple, pretokenized_input_pair, **kwargs)
|
||||
for key in output_p.keys():
|
||||
self.assertEqual(output_p[key], output_r[key])
|
||||
|
||||
# Test batch_encode_plus for pretokenized inputs
|
||||
input_batch_pair = ([pretokenized_input_simple, pretokenized_input_pair] * 2) + [
|
||||
pretokenized_input_simple + pretokenized_input_pair,
|
||||
pretokenized_input_pair,
|
||||
]
|
||||
output_r = tokenizer_r.batch_encode_plus(input_batch_pair, **kwargs)
|
||||
output_p = tokenizer_p.batch_encode_plus(input_batch_pair, **kwargs)
|
||||
for key in output_p.keys():
|
||||
self.assertEqual(output_p[key], output_r[key])
|
||||
|
||||
def assert_create_token_type_ids(self, tokenizer_r, tokenizer_p):
|
||||
input_simple = [1, 2, 3]
|
||||
input_pair = [1, 2, 3]
|
||||
@@ -357,17 +422,22 @@ class CommonFastTokenizerTest(unittest.TestCase):
|
||||
def assert_padded_input_match(input_r: list, input_p: list, max_length: int):
|
||||
|
||||
# Ensure we match max_length
|
||||
self.assertEqual(len(input_r), max_length), self.assertEqual(len(input_p), max_length)
|
||||
self.assertEqual(len(input_r), max_length)
|
||||
self.assertEqual(len(input_p), max_length)
|
||||
|
||||
# Ensure the number of padded tokens is the same
|
||||
padded_tokens_r = list(takewhile(lambda i: i == tokenizer_r.pad_token_id, reversed(input_r)))
|
||||
padded_tokens_p = list(takewhile(lambda i: i == tokenizer_p.pad_token_id, reversed(input_p)))
|
||||
self.assertSequenceEqual(padded_tokens_r, padded_tokens_p)
|
||||
|
||||
def assert_batch_padded_input_match(input_r: dict, input_p: dict):
|
||||
def assert_batch_padded_input_match(input_r: dict, input_p: dict, max_length: int):
|
||||
for i_r in input_r.values():
|
||||
self.assertEqual(len(i_r), 2), self.assertEqual(len(i_r[0]), 15), self.assertEqual(len(i_r[1]), 15)
|
||||
self.assertEqual(len(i_r), 2), self.assertEqual(len(i_r[0]), 15), self.assertEqual(len(i_r[1]), 15)
|
||||
self.assertEqual(len(i_r), 2), self.assertEqual(len(i_r[0]), max_length), self.assertEqual(
|
||||
len(i_r[1]), max_length
|
||||
)
|
||||
self.assertEqual(len(i_r), 2), self.assertEqual(len(i_r[0]), max_length), self.assertEqual(
|
||||
len(i_r[1]), max_length
|
||||
)
|
||||
|
||||
for i_r, i_p in zip(input_r["input_ids"], input_p["input_ids"]):
|
||||
assert_padded_input_match(i_r, i_p, max_length)
|
||||
@@ -375,12 +445,19 @@ class CommonFastTokenizerTest(unittest.TestCase):
|
||||
for i_r, i_p in zip(input_r["attention_mask"], input_p["attention_mask"]):
|
||||
self.assertSequenceEqual(i_r, i_p)
|
||||
|
||||
# Simple input
|
||||
# Encode - Simple input
|
||||
input_r = tokenizer_r.encode("This is a simple input", max_length=max_length, pad_to_max_length=True)
|
||||
input_p = tokenizer_p.encode("This is a simple input", max_length=max_length, pad_to_max_length=True)
|
||||
assert_padded_input_match(input_r, input_p, max_length)
|
||||
input_r = tokenizer_r.encode("This is a simple input", max_length=max_length, padding="max_length")
|
||||
input_p = tokenizer_p.encode("This is a simple input", max_length=max_length, padding="max_length")
|
||||
assert_padded_input_match(input_r, input_p, max_length)
|
||||
|
||||
# Pair input
|
||||
input_r = tokenizer_r.encode("This is a simple input", padding="longest")
|
||||
input_p = tokenizer_p.encode("This is a simple input", padding=True)
|
||||
assert_padded_input_match(input_r, input_p, len(input_r))
|
||||
|
||||
# Encode - Pair input
|
||||
input_r = tokenizer_r.encode(
|
||||
"This is a simple input", "This is a pair", max_length=max_length, pad_to_max_length=True
|
||||
)
|
||||
@@ -388,14 +465,34 @@ class CommonFastTokenizerTest(unittest.TestCase):
|
||||
"This is a simple input", "This is a pair", max_length=max_length, pad_to_max_length=True
|
||||
)
|
||||
assert_padded_input_match(input_r, input_p, max_length)
|
||||
input_r = tokenizer_r.encode(
|
||||
"This is a simple input", "This is a pair", max_length=max_length, padding="max_length"
|
||||
)
|
||||
input_p = tokenizer_p.encode(
|
||||
"This is a simple input", "This is a pair", max_length=max_length, padding="max_length"
|
||||
)
|
||||
assert_padded_input_match(input_r, input_p, max_length)
|
||||
input_r = tokenizer_r.encode("This is a simple input", "This is a pair", padding=True)
|
||||
input_p = tokenizer_p.encode("This is a simple input", "This is a pair", padding="longest")
|
||||
assert_padded_input_match(input_r, input_p, len(input_r))
|
||||
|
||||
# Simple input
|
||||
# Encode_plus - Simple input
|
||||
input_r = tokenizer_r.encode_plus("This is a simple input", max_length=max_length, pad_to_max_length=True)
|
||||
input_p = tokenizer_p.encode_plus("This is a simple input", max_length=max_length, pad_to_max_length=True)
|
||||
assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], max_length)
|
||||
self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"])
|
||||
input_r = tokenizer_r.encode_plus("This is a simple input", max_length=max_length, padding="max_length")
|
||||
input_p = tokenizer_p.encode_plus("This is a simple input", max_length=max_length, padding="max_length")
|
||||
assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], max_length)
|
||||
self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"])
|
||||
|
||||
# Pair input
|
||||
input_r = tokenizer_r.encode_plus("This is a simple input", padding="longest")
|
||||
input_p = tokenizer_p.encode_plus("This is a simple input", padding=True)
|
||||
assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], len(input_r["input_ids"]))
|
||||
|
||||
self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"])
|
||||
|
||||
# Encode_plus - Pair input
|
||||
input_r = tokenizer_r.encode_plus(
|
||||
"This is a simple input", "This is a pair", max_length=max_length, pad_to_max_length=True
|
||||
)
|
||||
@@ -404,34 +501,130 @@ class CommonFastTokenizerTest(unittest.TestCase):
|
||||
)
|
||||
assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], max_length)
|
||||
self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"])
|
||||
input_r = tokenizer_r.encode_plus(
|
||||
"This is a simple input", "This is a pair", max_length=max_length, padding="max_length"
|
||||
)
|
||||
input_p = tokenizer_p.encode_plus(
|
||||
"This is a simple input", "This is a pair", max_length=max_length, padding="max_length"
|
||||
)
|
||||
assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], max_length)
|
||||
self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"])
|
||||
input_r = tokenizer_r.encode_plus("This is a simple input", "This is a pair", padding="longest")
|
||||
input_p = tokenizer_p.encode_plus("This is a simple input", "This is a pair", padding=True)
|
||||
assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], len(input_r["input_ids"]))
|
||||
self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"])
|
||||
|
||||
# Simple input
|
||||
# Batch_encode_plus - Simple input
|
||||
input_r = tokenizer_r.batch_encode_plus(
|
||||
["This is a simple input 1", "This is a simple input 2"], max_length=max_length, pad_to_max_length=True
|
||||
)
|
||||
input_p = tokenizer_p.batch_encode_plus(
|
||||
["This is a simple input 1", "This is a simple input 2"], max_length=max_length, pad_to_max_length=True
|
||||
)
|
||||
assert_batch_padded_input_match(input_r, input_p)
|
||||
assert_batch_padded_input_match(input_r, input_p, max_length)
|
||||
|
||||
# Pair input
|
||||
input_r = tokenizer_r.batch_encode_plus(
|
||||
["This is a simple input 1", "This is a simple input 2"], max_length=max_length, padding="max_length",
|
||||
)
|
||||
input_p = tokenizer_p.batch_encode_plus(
|
||||
["This is a simple input 1", "This is a simple input 2"], max_length=max_length, padding="max_length",
|
||||
)
|
||||
assert_batch_padded_input_match(input_r, input_p, max_length)
|
||||
|
||||
input_r = tokenizer_r.batch_encode_plus(
|
||||
["This is a simple input 1", "This is a simple input 2"], max_length=max_length, padding="longest",
|
||||
)
|
||||
input_p = tokenizer_p.batch_encode_plus(
|
||||
["This is a simple input 1", "This is a simple input 2"], max_length=max_length, padding=True,
|
||||
)
|
||||
assert_batch_padded_input_match(input_r, input_p, len(input_r["input_ids"][0]))
|
||||
|
||||
input_r = tokenizer_r.batch_encode_plus(
|
||||
["This is a simple input 1", "This is a simple input 2"], padding="longest"
|
||||
)
|
||||
input_p = tokenizer_p.batch_encode_plus(["This is a simple input 1", "This is a simple input 2"], padding=True)
|
||||
assert_batch_padded_input_match(input_r, input_p, len(input_r["input_ids"][0]))
|
||||
|
||||
# Batch_encode_plus - Pair input
|
||||
input_r = tokenizer_r.batch_encode_plus(
|
||||
[
|
||||
("This is a simple input 1", "This is a simple input 2"),
|
||||
("This is a simple pair 1", "This is a simple pair 2"),
|
||||
],
|
||||
max_length=15,
|
||||
pad_to_max_length=True,
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
padding="max_length",
|
||||
)
|
||||
input_p = tokenizer_p.batch_encode_plus(
|
||||
[
|
||||
("This is a simple input 1", "This is a simple input 2"),
|
||||
("This is a simple pair 1", "This is a simple pair 2"),
|
||||
],
|
||||
max_length=15,
|
||||
pad_to_max_length=True,
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
padding="max_length",
|
||||
)
|
||||
assert_batch_padded_input_match(input_r, input_p)
|
||||
assert_batch_padded_input_match(input_r, input_p, max_length)
|
||||
|
||||
input_r = tokenizer_r.batch_encode_plus(
|
||||
[
|
||||
("This is a simple input 1", "This is a simple input 2"),
|
||||
("This is a simple pair 1", "This is a simple pair 2"),
|
||||
],
|
||||
padding=True,
|
||||
)
|
||||
input_p = tokenizer_p.batch_encode_plus(
|
||||
[
|
||||
("This is a simple input 1", "This is a simple input 2"),
|
||||
("This is a simple pair 1", "This is a simple pair 2"),
|
||||
],
|
||||
padding="longest",
|
||||
)
|
||||
assert_batch_padded_input_match(input_r, input_p, len(input_r["input_ids"][0]))
|
||||
|
||||
# Using pad on single examples after tokenization
|
||||
input_r = tokenizer_r.encode_plus("This is a input 1")
|
||||
input_r = tokenizer_r.pad(input_r)
|
||||
|
||||
input_p = tokenizer_r.encode_plus("This is a input 1")
|
||||
input_p = tokenizer_r.pad(input_p)
|
||||
|
||||
assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], len(input_r["input_ids"]))
|
||||
|
||||
# Using pad on single examples after tokenization
|
||||
input_r = tokenizer_r.encode_plus("This is a input 1")
|
||||
input_r = tokenizer_r.pad(input_r, max_length=max_length, padding="max_length")
|
||||
|
||||
input_p = tokenizer_r.encode_plus("This is a input 1")
|
||||
input_p = tokenizer_r.pad(input_p, max_length=max_length, padding="max_length")
|
||||
|
||||
assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], max_length)
|
||||
|
||||
# Using pad after tokenization
|
||||
input_r = tokenizer_r.batch_encode_plus(
|
||||
["This is a input 1", "This is a much longer input whilch should be padded"]
|
||||
)
|
||||
input_r = tokenizer_r.pad(input_r)
|
||||
|
||||
input_p = tokenizer_r.batch_encode_plus(
|
||||
["This is a input 1", "This is a much longer input whilch should be padded"]
|
||||
)
|
||||
input_p = tokenizer_r.pad(input_p)
|
||||
|
||||
assert_batch_padded_input_match(input_r, input_p, len(input_r["input_ids"][0]))
|
||||
|
||||
# Using pad after tokenization
|
||||
input_r = tokenizer_r.batch_encode_plus(
|
||||
["This is a input 1", "This is a much longer input whilch should be padded"]
|
||||
)
|
||||
input_r = tokenizer_r.pad(input_r, max_length=max_length, padding="max_length")
|
||||
|
||||
input_p = tokenizer_r.batch_encode_plus(
|
||||
["This is a input 1", "This is a much longer input whilch should be padded"]
|
||||
)
|
||||
input_p = tokenizer_r.pad(input_p, max_length=max_length, padding="max_length")
|
||||
|
||||
assert_batch_padded_input_match(input_r, input_p, max_length)
|
||||
|
||||
def assert_save_pretrained(self, tokenizer_r, tokenizer_p):
|
||||
# Checks it save with the same files
|
||||
@@ -503,8 +696,10 @@ class WordPieceFastTokenizerTest(CommonFastTokenizerTest):
|
||||
|
||||
TOKENIZERS_CLASSES = frozenset(
|
||||
[
|
||||
Tokenizer("Bert", BertTokenizerFast, BertTokenizer, "vocab_file", filter_non_english),
|
||||
Tokenizer("DistilBert", DistilBertTokenizerFast, DistilBertTokenizer, "vocab_file", filter_non_english),
|
||||
Tokenizer("Bert", BertTokenizerFast, BertTokenizer, "vocab_file", filter_non_english, None),
|
||||
Tokenizer(
|
||||
"DistilBert", DistilBertTokenizerFast, DistilBertTokenizer, "vocab_file", filter_non_english, None
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
@@ -552,7 +747,7 @@ class WordPieceFastTokenizerTest(CommonFastTokenizerTest):
|
||||
|
||||
class RobertaFastTokenizerTest(CommonFastTokenizerTest):
|
||||
TOKENIZERS_CLASSES = frozenset(
|
||||
[Tokenizer("Roberta", RobertaTokenizerFast, RobertaTokenizer, "vocab_file", filter_roberta_detectors)]
|
||||
[Tokenizer("Roberta", RobertaTokenizerFast, RobertaTokenizer, "vocab_file", filter_roberta_detectors, None)]
|
||||
)
|
||||
|
||||
def assert_embeded_special_tokens(self, tokenizer_r, tokenizer_p):
|
||||
@@ -580,10 +775,30 @@ class RobertaFastTokenizerTest(CommonFastTokenizerTest):
|
||||
|
||||
class NoPaddingTokenFastTokenizerMatchingTest(CommonFastTokenizerTest):
|
||||
TOKENIZERS_CLASSES = [
|
||||
Tokenizer("OpenAI GPT", OpenAIGPTTokenizerFast, OpenAIGPTTokenizer, "vocab_file", None),
|
||||
Tokenizer("GPT2", GPT2TokenizerFast, GPT2Tokenizer, "vocab_file", None),
|
||||
Tokenizer("OpenAI GPT", OpenAIGPTTokenizerFast, OpenAIGPTTokenizer, "vocab_file", None, None),
|
||||
Tokenizer("GPT2", GPT2TokenizerFast, GPT2Tokenizer, "vocab_file", None, [("add_prefix_space", True)]),
|
||||
]
|
||||
|
||||
def fast_align_python(self, tokenizer_r, tokenizer_p, tok_case, pretrained_name):
|
||||
# Check is_fast is set correctly
|
||||
self.assertFalse(tokenizer_p.is_fast)
|
||||
self.assertTrue(tokenizer_r.is_fast)
|
||||
|
||||
# Check that Rust and Python align
|
||||
self.assert_tokenization_python_rust_equals(tokenizer_r, tokenizer_p)
|
||||
self.assert_num_special_tokens_to_add_equal(tokenizer_r, tokenizer_p)
|
||||
self.assert_max_length_equal(tokenizer_r, tokenizer_p)
|
||||
self.assert_special_tokens_map_equal(tokenizer_r, tokenizer_p)
|
||||
self.assert_embeded_special_tokens(tokenizer_r, tokenizer_p)
|
||||
self.assert_padding(tokenizer_r, tokenizer_p)
|
||||
|
||||
# Specific for
|
||||
kwargs = {}
|
||||
if tok_case.kwargs is not None:
|
||||
kwargs = dict(tok_case.kwargs)
|
||||
tokenizer_r = tok_case.rust_cls.from_pretrained(pretrained_name, **kwargs)
|
||||
self.assert_pretokenized_inputs(tokenizer_r, tokenizer_p)
|
||||
|
||||
def assert_padding(self, tokenizer_r, tokenizer_p, max_length=15):
|
||||
# Simple input
|
||||
s = "This is a simple input"
|
||||
@@ -595,27 +810,31 @@ class NoPaddingTokenFastTokenizerMatchingTest(CommonFastTokenizerTest):
|
||||
]
|
||||
|
||||
# Simple input tests
|
||||
self.assertRaises(ValueError, tokenizer_r.encode, s, max_length=max_length, pad_to_max_length=True)
|
||||
self.assertRaises(ValueError, tokenizer_r.encode, s, max_length=max_length, padding="max_length")
|
||||
|
||||
# Simple input
|
||||
self.assertRaises(ValueError, tokenizer_r.encode_plus, s, max_length=max_length, pad_to_max_length=True)
|
||||
self.assertRaises(ValueError, tokenizer_r.encode_plus, s, max_length=max_length, padding="max_length")
|
||||
|
||||
# Simple input
|
||||
self.assertRaises(ValueError, tokenizer_r.batch_encode_plus, s2, max_length=max_length, pad_to_max_length=True)
|
||||
self.assertRaises(
|
||||
ValueError, tokenizer_r.batch_encode_plus, s2, max_length=max_length, padding="max_length",
|
||||
)
|
||||
|
||||
# Pair input
|
||||
self.assertRaises(ValueError, tokenizer_r.encode, p, max_length=max_length, pad_to_max_length=True)
|
||||
self.assertRaises(ValueError, tokenizer_r.encode, p, max_length=max_length, padding="max_length")
|
||||
|
||||
# Pair input
|
||||
self.assertRaises(ValueError, tokenizer_r.encode_plus, p, max_length=max_length, pad_to_max_length=True)
|
||||
self.assertRaises(ValueError, tokenizer_r.encode_plus, p, max_length=max_length, padding="max_length")
|
||||
|
||||
# Pair input
|
||||
self.assertRaises(ValueError, tokenizer_r.batch_encode_plus, p2, max_length=max_length, pad_to_max_length=True)
|
||||
self.assertRaises(
|
||||
ValueError, tokenizer_r.batch_encode_plus, p2, max_length=max_length, padding="max_length",
|
||||
)
|
||||
|
||||
|
||||
class TransfoXLFastTokenizerTest(NoPaddingTokenFastTokenizerMatchingTest):
|
||||
TOKENIZERS_CLASSES = frozenset(
|
||||
[Tokenizer("TransfoXL", TransfoXLTokenizerFast, TransfoXLTokenizer, "pretrained_vocab_file", None)]
|
||||
[Tokenizer("TransfoXL", TransfoXLTokenizerFast, TransfoXLTokenizer, "pretrained_vocab_file", None, None)]
|
||||
)
|
||||
|
||||
@require_torch
|
||||
|
||||
@@ -53,6 +53,7 @@ class GPT2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
"\u0120newer",
|
||||
"\u0120wider",
|
||||
"<unk>",
|
||||
"<|endoftext|>",
|
||||
]
|
||||
vocab_tokens = dict(zip(vocab, range(len(vocab))))
|
||||
merges = ["#version: 0.2", "\u0120 l", "\u0120l o", "\u0120lo w", "e r", ""]
|
||||
@@ -73,7 +74,7 @@ class GPT2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
kwargs.update(self.special_tokens_map)
|
||||
return GPT2TokenizerFast.from_pretrained(self.tmpdirname, **kwargs)
|
||||
|
||||
def get_input_output_texts(self):
|
||||
def get_input_output_texts(self, tokenizer):
|
||||
input_text = "lower newer"
|
||||
output_text = "lower newer"
|
||||
return input_text, output_text
|
||||
@@ -118,3 +119,8 @@ class GPT2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
input_tokens = tokens + [rust_tokenizer.unk_token]
|
||||
input_bpe_tokens = [14, 15, 10, 9, 3, 2, 15, 19]
|
||||
self.assertListEqual(rust_tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
|
||||
|
||||
def test_pretokenized_inputs(self, *args, **kwargs):
|
||||
# It's very difficult to mix/test pretokenization with byte-level
|
||||
# And get both GPT2 and Roberta to work at the same time (mostly an issue of adding a space before the string)
|
||||
pass
|
||||
|
||||
@@ -51,10 +51,10 @@ class MarianTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
tokenizer = MarianTokenizer.from_pretrained(self.tmpdirname)
|
||||
tokenizer.save_pretrained(self.tmpdirname)
|
||||
|
||||
def get_tokenizer(self, max_len=None, **kwargs) -> MarianTokenizer:
|
||||
return MarianTokenizer.from_pretrained(self.tmpdirname, model_max_length=max_len, **kwargs)
|
||||
def get_tokenizer(self, **kwargs) -> MarianTokenizer:
|
||||
return MarianTokenizer.from_pretrained(self.tmpdirname, **kwargs)
|
||||
|
||||
def get_input_output_texts(self):
|
||||
def get_input_output_texts(self, tokenizer):
|
||||
return (
|
||||
"This is a test",
|
||||
"This is a test",
|
||||
|
||||
@@ -64,7 +64,7 @@ class OpenAIGPTTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
with open(self.merges_file, "w") as fp:
|
||||
fp.write("\n".join(merges))
|
||||
|
||||
def get_input_output_texts(self):
|
||||
def get_input_output_texts(self, tokenizer):
|
||||
return "lower newer", "lower newer"
|
||||
|
||||
def test_full_tokenizer(self):
|
||||
|
||||
@@ -18,7 +18,7 @@ import json
|
||||
import os
|
||||
import unittest
|
||||
|
||||
from transformers.tokenization_roberta import VOCAB_FILES_NAMES, RobertaTokenizer
|
||||
from transformers.tokenization_roberta import VOCAB_FILES_NAMES, RobertaTokenizer, RobertaTokenizerFast
|
||||
|
||||
from .test_tokenization_common import TokenizerTesterMixin
|
||||
from .utils import slow
|
||||
@@ -68,7 +68,11 @@ class RobertaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
kwargs.update(self.special_tokens_map)
|
||||
return RobertaTokenizer.from_pretrained(self.tmpdirname, **kwargs)
|
||||
|
||||
def get_input_output_texts(self):
|
||||
def get_rust_tokenizer(self, **kwargs):
|
||||
kwargs.update(self.special_tokens_map)
|
||||
return RobertaTokenizerFast.from_pretrained(self.tmpdirname, **kwargs)
|
||||
|
||||
def get_input_output_texts(self, tokenizer):
|
||||
input_text = "lower newer"
|
||||
output_text = "lower newer"
|
||||
return input_text, output_text
|
||||
|
||||
@@ -56,7 +56,7 @@ class TransfoXLTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
kwargs["lower_case"] = True
|
||||
return TransfoXLTokenizer.from_pretrained(self.tmpdirname, **kwargs)
|
||||
|
||||
def get_input_output_texts(self):
|
||||
def get_input_output_texts(self, tokenizer):
|
||||
input_text = "<unk> UNwanted , running"
|
||||
output_text = "<unk> unwanted, running"
|
||||
return input_text, output_text
|
||||
|
||||
@@ -65,7 +65,7 @@ class XLMTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
with open(self.merges_file, "w") as fp:
|
||||
fp.write("\n".join(merges))
|
||||
|
||||
def get_input_output_texts(self):
|
||||
def get_input_output_texts(self, tokenizer):
|
||||
input_text = "lower newer"
|
||||
output_text = "lower newer"
|
||||
return input_text, output_text
|
||||
|
||||
Reference in New Issue
Block a user