[Tokenizer Utils Base] Make pad function more flexible (#9928)
* change tokenizer requirement * split line * Correct typo from list to str * improve style * make other function pretty as well * add comment * correct typo * add new test * pass tests for tok without padding token * Apply suggestions from code review
This commit is contained in:
committed by
GitHub
parent
d1b14c9b54
commit
538b3b4607
@@ -160,6 +160,38 @@ class TokenizerTesterMixin:
|
||||
# "This is a test",
|
||||
# )
|
||||
|
||||
def assert_padded_input_match(self, input_r: list, input_p: list, max_length: int, pad_token_id: int):
|
||||
# Ensure we match max_length
|
||||
self.assertEqual(len(input_r), max_length)
|
||||
self.assertEqual(len(input_p), max_length)
|
||||
|
||||
# Ensure the number of padded tokens is the same
|
||||
padded_tokens_r = list(takewhile(lambda i: i == pad_token_id, reversed(input_r)))
|
||||
padded_tokens_p = list(takewhile(lambda i: i == pad_token_id, reversed(input_p)))
|
||||
self.assertSequenceEqual(padded_tokens_r, padded_tokens_p)
|
||||
|
||||
def assert_batch_padded_input_match(
|
||||
self,
|
||||
input_r: dict,
|
||||
input_p: dict,
|
||||
max_length: int,
|
||||
pad_token_id: int,
|
||||
model_main_input_name: str = "input_ids",
|
||||
):
|
||||
for i_r in input_r.values():
|
||||
self.assertEqual(len(i_r), 2), self.assertEqual(len(i_r[0]), max_length), self.assertEqual(
|
||||
len(i_r[1]), max_length
|
||||
)
|
||||
self.assertEqual(len(i_r), 2), self.assertEqual(len(i_r[0]), max_length), self.assertEqual(
|
||||
len(i_r[1]), max_length
|
||||
)
|
||||
|
||||
for i_r, i_p in zip(input_r[model_main_input_name], input_p[model_main_input_name]):
|
||||
self.assert_padded_input_match(i_r, i_p, max_length, pad_token_id)
|
||||
|
||||
for i_r, i_p in zip(input_r["attention_mask"], input_p["attention_mask"]):
|
||||
self.assertSequenceEqual(i_r, i_p)
|
||||
|
||||
@staticmethod
|
||||
def convert_batch_encode_plus_format_to_encode_plus(batch_encode_plus_sequences):
|
||||
# Switch from batch_encode_plus format: {'input_ids': [[...], [...]], ...}
|
||||
@@ -169,6 +201,18 @@ class TokenizerTesterMixin:
|
||||
for i in range(len(batch_encode_plus_sequences["input_ids"]))
|
||||
]
|
||||
|
||||
def test_model_input_names_signature(self):
|
||||
accepted_model_main_input_names = [
|
||||
"input_ids", # nlp models
|
||||
"input_values", # speech models
|
||||
]
|
||||
|
||||
tokenizers = self.get_tokenizers()
|
||||
for tokenizer in tokenizers:
|
||||
# first name of model_input_names has to correspond to main model input name
|
||||
# to make sure `tokenizer.pad(...)` works correctly
|
||||
self.assertTrue(tokenizer.model_input_names[0] in accepted_model_main_input_names)
|
||||
|
||||
def test_rust_tokenizer_signature(self):
|
||||
if not self.test_rust_tokenizer:
|
||||
return
|
||||
@@ -2419,43 +2463,20 @@ class TokenizerTesterMixin:
|
||||
tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs)
|
||||
tokenizer_p = self.tokenizer_class.from_pretrained(pretrained_name, **kwargs)
|
||||
|
||||
def assert_padded_input_match(input_r: list, input_p: list, max_length: int):
|
||||
|
||||
# Ensure we match max_length
|
||||
self.assertEqual(len(input_r), max_length)
|
||||
self.assertEqual(len(input_p), max_length)
|
||||
|
||||
# Ensure the number of padded tokens is the same
|
||||
padded_tokens_r = list(takewhile(lambda i: i == tokenizer_r.pad_token_id, reversed(input_r)))
|
||||
padded_tokens_p = list(takewhile(lambda i: i == tokenizer_p.pad_token_id, reversed(input_p)))
|
||||
self.assertSequenceEqual(padded_tokens_r, padded_tokens_p)
|
||||
|
||||
def assert_batch_padded_input_match(input_r: dict, input_p: dict, max_length: int):
|
||||
for i_r in input_r.values():
|
||||
self.assertEqual(len(i_r), 2), self.assertEqual(len(i_r[0]), max_length), self.assertEqual(
|
||||
len(i_r[1]), max_length
|
||||
)
|
||||
self.assertEqual(len(i_r), 2), self.assertEqual(len(i_r[0]), max_length), self.assertEqual(
|
||||
len(i_r[1]), max_length
|
||||
)
|
||||
|
||||
for i_r, i_p in zip(input_r["input_ids"], input_p["input_ids"]):
|
||||
assert_padded_input_match(i_r, i_p, max_length)
|
||||
|
||||
for i_r, i_p in zip(input_r["attention_mask"], input_p["attention_mask"]):
|
||||
self.assertSequenceEqual(i_r, i_p)
|
||||
self.assertEqual(tokenizer_p.pad_token_id, tokenizer_r.pad_token_id)
|
||||
pad_token_id = tokenizer_p.pad_token_id
|
||||
|
||||
# Encode - Simple input
|
||||
input_r = tokenizer_r.encode("This is a simple input", max_length=max_length, pad_to_max_length=True)
|
||||
input_p = tokenizer_p.encode("This is a simple input", max_length=max_length, pad_to_max_length=True)
|
||||
assert_padded_input_match(input_r, input_p, max_length)
|
||||
self.assert_padded_input_match(input_r, input_p, max_length, pad_token_id)
|
||||
input_r = tokenizer_r.encode("This is a simple input", max_length=max_length, padding="max_length")
|
||||
input_p = tokenizer_p.encode("This is a simple input", max_length=max_length, padding="max_length")
|
||||
assert_padded_input_match(input_r, input_p, max_length)
|
||||
self.assert_padded_input_match(input_r, input_p, max_length, pad_token_id)
|
||||
|
||||
input_r = tokenizer_r.encode("This is a simple input", padding="longest")
|
||||
input_p = tokenizer_p.encode("This is a simple input", padding=True)
|
||||
assert_padded_input_match(input_r, input_p, len(input_r))
|
||||
self.assert_padded_input_match(input_r, input_p, len(input_r), pad_token_id)
|
||||
|
||||
# Encode - Pair input
|
||||
input_r = tokenizer_r.encode(
|
||||
@@ -2464,17 +2485,17 @@ class TokenizerTesterMixin:
|
||||
input_p = tokenizer_p.encode(
|
||||
"This is a simple input", "This is a pair", max_length=max_length, pad_to_max_length=True
|
||||
)
|
||||
assert_padded_input_match(input_r, input_p, max_length)
|
||||
self.assert_padded_input_match(input_r, input_p, max_length, pad_token_id)
|
||||
input_r = tokenizer_r.encode(
|
||||
"This is a simple input", "This is a pair", max_length=max_length, padding="max_length"
|
||||
)
|
||||
input_p = tokenizer_p.encode(
|
||||
"This is a simple input", "This is a pair", max_length=max_length, padding="max_length"
|
||||
)
|
||||
assert_padded_input_match(input_r, input_p, max_length)
|
||||
self.assert_padded_input_match(input_r, input_p, max_length, pad_token_id)
|
||||
input_r = tokenizer_r.encode("This is a simple input", "This is a pair", padding=True)
|
||||
input_p = tokenizer_p.encode("This is a simple input", "This is a pair", padding="longest")
|
||||
assert_padded_input_match(input_r, input_p, len(input_r))
|
||||
self.assert_padded_input_match(input_r, input_p, len(input_r), pad_token_id)
|
||||
|
||||
# Encode_plus - Simple input
|
||||
input_r = tokenizer_r.encode_plus(
|
||||
@@ -2483,7 +2504,7 @@ class TokenizerTesterMixin:
|
||||
input_p = tokenizer_p.encode_plus(
|
||||
"This is a simple input", max_length=max_length, pad_to_max_length=True
|
||||
)
|
||||
assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], max_length)
|
||||
self.assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], max_length, pad_token_id)
|
||||
self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"])
|
||||
input_r = tokenizer_r.encode_plus(
|
||||
"This is a simple input", max_length=max_length, padding="max_length"
|
||||
@@ -2491,12 +2512,14 @@ class TokenizerTesterMixin:
|
||||
input_p = tokenizer_p.encode_plus(
|
||||
"This is a simple input", max_length=max_length, padding="max_length"
|
||||
)
|
||||
assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], max_length)
|
||||
self.assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], max_length, pad_token_id)
|
||||
self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"])
|
||||
|
||||
input_r = tokenizer_r.encode_plus("This is a simple input", padding="longest")
|
||||
input_p = tokenizer_p.encode_plus("This is a simple input", padding=True)
|
||||
assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], len(input_r["input_ids"]))
|
||||
self.assert_padded_input_match(
|
||||
input_r["input_ids"], input_p["input_ids"], len(input_r["input_ids"]), pad_token_id
|
||||
)
|
||||
|
||||
self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"])
|
||||
|
||||
@@ -2507,7 +2530,7 @@ class TokenizerTesterMixin:
|
||||
input_p = tokenizer_p.encode_plus(
|
||||
"This is a simple input", "This is a pair", max_length=max_length, pad_to_max_length=True
|
||||
)
|
||||
assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], max_length)
|
||||
self.assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], max_length, pad_token_id)
|
||||
self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"])
|
||||
input_r = tokenizer_r.encode_plus(
|
||||
"This is a simple input", "This is a pair", max_length=max_length, padding="max_length"
|
||||
@@ -2515,11 +2538,13 @@ class TokenizerTesterMixin:
|
||||
input_p = tokenizer_p.encode_plus(
|
||||
"This is a simple input", "This is a pair", max_length=max_length, padding="max_length"
|
||||
)
|
||||
assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], max_length)
|
||||
self.assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], max_length, pad_token_id)
|
||||
self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"])
|
||||
input_r = tokenizer_r.encode_plus("This is a simple input", "This is a pair", padding="longest")
|
||||
input_p = tokenizer_p.encode_plus("This is a simple input", "This is a pair", padding=True)
|
||||
assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], len(input_r["input_ids"]))
|
||||
self.assert_padded_input_match(
|
||||
input_r["input_ids"], input_p["input_ids"], len(input_r["input_ids"]), pad_token_id
|
||||
)
|
||||
self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"])
|
||||
|
||||
# Batch_encode_plus - Simple input
|
||||
@@ -2533,7 +2558,7 @@ class TokenizerTesterMixin:
|
||||
max_length=max_length,
|
||||
pad_to_max_length=True,
|
||||
)
|
||||
assert_batch_padded_input_match(input_r, input_p, max_length)
|
||||
self.assert_batch_padded_input_match(input_r, input_p, max_length, pad_token_id)
|
||||
|
||||
input_r = tokenizer_r.batch_encode_plus(
|
||||
["This is a simple input 1", "This is a simple input 2"],
|
||||
@@ -2545,7 +2570,7 @@ class TokenizerTesterMixin:
|
||||
max_length=max_length,
|
||||
padding="max_length",
|
||||
)
|
||||
assert_batch_padded_input_match(input_r, input_p, max_length)
|
||||
self.assert_batch_padded_input_match(input_r, input_p, max_length, pad_token_id)
|
||||
|
||||
input_r = tokenizer_r.batch_encode_plus(
|
||||
["This is a simple input 1", "This is a simple input 2"],
|
||||
@@ -2557,7 +2582,7 @@ class TokenizerTesterMixin:
|
||||
max_length=max_length,
|
||||
padding=True,
|
||||
)
|
||||
assert_batch_padded_input_match(input_r, input_p, len(input_r["input_ids"][0]))
|
||||
self.assert_batch_padded_input_match(input_r, input_p, len(input_r["input_ids"][0]), pad_token_id)
|
||||
|
||||
input_r = tokenizer_r.batch_encode_plus(
|
||||
["This is a simple input 1", "This is a simple input 2"], padding="longest"
|
||||
@@ -2565,7 +2590,7 @@ class TokenizerTesterMixin:
|
||||
input_p = tokenizer_p.batch_encode_plus(
|
||||
["This is a simple input 1", "This is a simple input 2"], padding=True
|
||||
)
|
||||
assert_batch_padded_input_match(input_r, input_p, len(input_r["input_ids"][0]))
|
||||
self.assert_batch_padded_input_match(input_r, input_p, len(input_r["input_ids"][0]), pad_token_id)
|
||||
|
||||
# Batch_encode_plus - Pair input
|
||||
input_r = tokenizer_r.batch_encode_plus(
|
||||
@@ -2586,7 +2611,7 @@ class TokenizerTesterMixin:
|
||||
truncation=True,
|
||||
padding="max_length",
|
||||
)
|
||||
assert_batch_padded_input_match(input_r, input_p, max_length)
|
||||
self.assert_batch_padded_input_match(input_r, input_p, max_length, pad_token_id)
|
||||
|
||||
input_r = tokenizer_r.batch_encode_plus(
|
||||
[
|
||||
@@ -2602,7 +2627,7 @@ class TokenizerTesterMixin:
|
||||
],
|
||||
padding="longest",
|
||||
)
|
||||
assert_batch_padded_input_match(input_r, input_p, len(input_r["input_ids"][0]))
|
||||
self.assert_batch_padded_input_match(input_r, input_p, len(input_r["input_ids"][0]), pad_token_id)
|
||||
|
||||
# Using pad on single examples after tokenization
|
||||
input_r = tokenizer_r.encode_plus("This is a input 1")
|
||||
@@ -2611,7 +2636,9 @@ class TokenizerTesterMixin:
|
||||
input_p = tokenizer_r.encode_plus("This is a input 1")
|
||||
input_p = tokenizer_r.pad(input_p)
|
||||
|
||||
assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], len(input_r["input_ids"]))
|
||||
self.assert_padded_input_match(
|
||||
input_r["input_ids"], input_p["input_ids"], len(input_r["input_ids"]), pad_token_id
|
||||
)
|
||||
|
||||
# Using pad on single examples after tokenization
|
||||
input_r = tokenizer_r.encode_plus("This is a input 1")
|
||||
@@ -2620,7 +2647,7 @@ class TokenizerTesterMixin:
|
||||
input_p = tokenizer_r.encode_plus("This is a input 1")
|
||||
input_p = tokenizer_r.pad(input_p, max_length=max_length, padding="max_length")
|
||||
|
||||
assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], max_length)
|
||||
self.assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], max_length, pad_token_id)
|
||||
|
||||
# Using pad after tokenization
|
||||
input_r = tokenizer_r.batch_encode_plus(
|
||||
@@ -2633,7 +2660,7 @@ class TokenizerTesterMixin:
|
||||
)
|
||||
input_p = tokenizer_r.pad(input_p)
|
||||
|
||||
assert_batch_padded_input_match(input_r, input_p, len(input_r["input_ids"][0]))
|
||||
self.assert_batch_padded_input_match(input_r, input_p, len(input_r["input_ids"][0]), pad_token_id)
|
||||
|
||||
# Using pad after tokenization
|
||||
input_r = tokenizer_r.batch_encode_plus(
|
||||
@@ -2646,7 +2673,41 @@ class TokenizerTesterMixin:
|
||||
)
|
||||
input_p = tokenizer_r.pad(input_p, max_length=max_length, padding="max_length")
|
||||
|
||||
assert_batch_padded_input_match(input_r, input_p, max_length)
|
||||
self.assert_batch_padded_input_match(input_r, input_p, max_length, pad_token_id)
|
||||
|
||||
def test_padding_different_model_input_name(self):
|
||||
for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
|
||||
with self.subTest("{} ({})".format(tokenizer.__class__.__name__, pretrained_name)):
|
||||
tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs)
|
||||
tokenizer_p = self.tokenizer_class.from_pretrained(pretrained_name, **kwargs)
|
||||
self.assertEqual(tokenizer_p.pad_token_id, tokenizer_r.pad_token_id)
|
||||
pad_token_id = tokenizer_p.pad_token_id
|
||||
|
||||
input_r = tokenizer_r.batch_encode_plus(
|
||||
["This is a input 1", "This is a much longer input whilch should be padded"]
|
||||
)
|
||||
input_p = tokenizer_r.batch_encode_plus(
|
||||
["This is a input 1", "This is a much longer input whilch should be padded"]
|
||||
)
|
||||
|
||||
# rename encoded batch to "inputs"
|
||||
input_r["inputs"] = input_r[tokenizer_r.model_input_names[0]]
|
||||
del input_r[tokenizer_r.model_input_names[0]]
|
||||
|
||||
input_p["inputs"] = input_p[tokenizer_p.model_input_names[0]]
|
||||
del input_p[tokenizer_p.model_input_names[0]]
|
||||
|
||||
# Renaming `input_ids` to `inputs`
|
||||
tokenizer_r.model_input_names = ["inputs"] + tokenizer_r.model_input_names[1:]
|
||||
tokenizer_p.model_input_names = ["inputs"] + tokenizer_p.model_input_names[1:]
|
||||
|
||||
input_r = tokenizer_r.pad(input_r, padding="longest")
|
||||
input_p = tokenizer_r.pad(input_p, padding="longest")
|
||||
|
||||
max_length = len(input_p["inputs"][0])
|
||||
self.assert_batch_padded_input_match(
|
||||
input_r, input_p, max_length, pad_token_id, model_main_input_name="inputs"
|
||||
)
|
||||
|
||||
def test_save_pretrained(self):
|
||||
for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
|
||||
|
||||
@@ -174,3 +174,7 @@ class GPT2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
max_length=max_length,
|
||||
padding="max_length",
|
||||
)
|
||||
|
||||
# tokenizer has no padding token
|
||||
def test_padding_different_model_input_name(self):
|
||||
pass
|
||||
|
||||
@@ -128,3 +128,7 @@ class OpenAIGPTTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
max_length=max_length,
|
||||
padding="max_length",
|
||||
)
|
||||
|
||||
# tokenizer has no padding token
|
||||
def test_padding_different_model_input_name(self):
|
||||
pass
|
||||
|
||||
@@ -107,6 +107,10 @@ class ReformerTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
padding="max_length",
|
||||
)
|
||||
|
||||
# tokenizer has no padding token
|
||||
def test_padding_different_model_input_name(self):
|
||||
pass
|
||||
|
||||
def test_full_tokenizer(self):
|
||||
tokenizer = ReformerTokenizer(SAMPLE_VOCAB, keep_accents=True)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user