[Tokenizer Utils Base] Make pad function more flexible (#9928)

* change tokenizer requirement

* split line

* Correct typo from list to str

* improve style

* make other function pretty as well

* add comment

* correct typo

* add new test

* pass tests for tok without padding token

* Apply suggestions from code review
This commit is contained in:
Patrick von Platen
2021-02-02 10:35:27 +03:00
committed by GitHub
parent d1b14c9b54
commit 538b3b4607
40 changed files with 187 additions and 107 deletions

View File

@@ -160,6 +160,38 @@ class TokenizerTesterMixin:
# "This is a test",
# )
def assert_padded_input_match(self, input_r: list, input_p: list, max_length: int, pad_token_id: int):
# Ensure we match max_length
self.assertEqual(len(input_r), max_length)
self.assertEqual(len(input_p), max_length)
# Ensure the number of padded tokens is the same
padded_tokens_r = list(takewhile(lambda i: i == pad_token_id, reversed(input_r)))
padded_tokens_p = list(takewhile(lambda i: i == pad_token_id, reversed(input_p)))
self.assertSequenceEqual(padded_tokens_r, padded_tokens_p)
def assert_batch_padded_input_match(
self,
input_r: dict,
input_p: dict,
max_length: int,
pad_token_id: int,
model_main_input_name: str = "input_ids",
):
for i_r in input_r.values():
self.assertEqual(len(i_r), 2), self.assertEqual(len(i_r[0]), max_length), self.assertEqual(
len(i_r[1]), max_length
)
self.assertEqual(len(i_r), 2), self.assertEqual(len(i_r[0]), max_length), self.assertEqual(
len(i_r[1]), max_length
)
for i_r, i_p in zip(input_r[model_main_input_name], input_p[model_main_input_name]):
self.assert_padded_input_match(i_r, i_p, max_length, pad_token_id)
for i_r, i_p in zip(input_r["attention_mask"], input_p["attention_mask"]):
self.assertSequenceEqual(i_r, i_p)
@staticmethod
def convert_batch_encode_plus_format_to_encode_plus(batch_encode_plus_sequences):
# Switch from batch_encode_plus format: {'input_ids': [[...], [...]], ...}
@@ -169,6 +201,18 @@ class TokenizerTesterMixin:
for i in range(len(batch_encode_plus_sequences["input_ids"]))
]
def test_model_input_names_signature(self):
accepted_model_main_input_names = [
"input_ids", # nlp models
"input_values", # speech models
]
tokenizers = self.get_tokenizers()
for tokenizer in tokenizers:
# first name of model_input_names has to correspond to main model input name
# to make sure `tokenizer.pad(...)` works correctly
self.assertTrue(tokenizer.model_input_names[0] in accepted_model_main_input_names)
def test_rust_tokenizer_signature(self):
if not self.test_rust_tokenizer:
return
@@ -2419,43 +2463,20 @@ class TokenizerTesterMixin:
tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs)
tokenizer_p = self.tokenizer_class.from_pretrained(pretrained_name, **kwargs)
def assert_padded_input_match(input_r: list, input_p: list, max_length: int):
# Ensure we match max_length
self.assertEqual(len(input_r), max_length)
self.assertEqual(len(input_p), max_length)
# Ensure the number of padded tokens is the same
padded_tokens_r = list(takewhile(lambda i: i == tokenizer_r.pad_token_id, reversed(input_r)))
padded_tokens_p = list(takewhile(lambda i: i == tokenizer_p.pad_token_id, reversed(input_p)))
self.assertSequenceEqual(padded_tokens_r, padded_tokens_p)
def assert_batch_padded_input_match(input_r: dict, input_p: dict, max_length: int):
for i_r in input_r.values():
self.assertEqual(len(i_r), 2), self.assertEqual(len(i_r[0]), max_length), self.assertEqual(
len(i_r[1]), max_length
)
self.assertEqual(len(i_r), 2), self.assertEqual(len(i_r[0]), max_length), self.assertEqual(
len(i_r[1]), max_length
)
for i_r, i_p in zip(input_r["input_ids"], input_p["input_ids"]):
assert_padded_input_match(i_r, i_p, max_length)
for i_r, i_p in zip(input_r["attention_mask"], input_p["attention_mask"]):
self.assertSequenceEqual(i_r, i_p)
self.assertEqual(tokenizer_p.pad_token_id, tokenizer_r.pad_token_id)
pad_token_id = tokenizer_p.pad_token_id
# Encode - Simple input
input_r = tokenizer_r.encode("This is a simple input", max_length=max_length, pad_to_max_length=True)
input_p = tokenizer_p.encode("This is a simple input", max_length=max_length, pad_to_max_length=True)
assert_padded_input_match(input_r, input_p, max_length)
self.assert_padded_input_match(input_r, input_p, max_length, pad_token_id)
input_r = tokenizer_r.encode("This is a simple input", max_length=max_length, padding="max_length")
input_p = tokenizer_p.encode("This is a simple input", max_length=max_length, padding="max_length")
assert_padded_input_match(input_r, input_p, max_length)
self.assert_padded_input_match(input_r, input_p, max_length, pad_token_id)
input_r = tokenizer_r.encode("This is a simple input", padding="longest")
input_p = tokenizer_p.encode("This is a simple input", padding=True)
assert_padded_input_match(input_r, input_p, len(input_r))
self.assert_padded_input_match(input_r, input_p, len(input_r), pad_token_id)
# Encode - Pair input
input_r = tokenizer_r.encode(
@@ -2464,17 +2485,17 @@ class TokenizerTesterMixin:
input_p = tokenizer_p.encode(
"This is a simple input", "This is a pair", max_length=max_length, pad_to_max_length=True
)
assert_padded_input_match(input_r, input_p, max_length)
self.assert_padded_input_match(input_r, input_p, max_length, pad_token_id)
input_r = tokenizer_r.encode(
"This is a simple input", "This is a pair", max_length=max_length, padding="max_length"
)
input_p = tokenizer_p.encode(
"This is a simple input", "This is a pair", max_length=max_length, padding="max_length"
)
assert_padded_input_match(input_r, input_p, max_length)
self.assert_padded_input_match(input_r, input_p, max_length, pad_token_id)
input_r = tokenizer_r.encode("This is a simple input", "This is a pair", padding=True)
input_p = tokenizer_p.encode("This is a simple input", "This is a pair", padding="longest")
assert_padded_input_match(input_r, input_p, len(input_r))
self.assert_padded_input_match(input_r, input_p, len(input_r), pad_token_id)
# Encode_plus - Simple input
input_r = tokenizer_r.encode_plus(
@@ -2483,7 +2504,7 @@ class TokenizerTesterMixin:
input_p = tokenizer_p.encode_plus(
"This is a simple input", max_length=max_length, pad_to_max_length=True
)
assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], max_length)
self.assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], max_length, pad_token_id)
self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"])
input_r = tokenizer_r.encode_plus(
"This is a simple input", max_length=max_length, padding="max_length"
@@ -2491,12 +2512,14 @@ class TokenizerTesterMixin:
input_p = tokenizer_p.encode_plus(
"This is a simple input", max_length=max_length, padding="max_length"
)
assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], max_length)
self.assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], max_length, pad_token_id)
self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"])
input_r = tokenizer_r.encode_plus("This is a simple input", padding="longest")
input_p = tokenizer_p.encode_plus("This is a simple input", padding=True)
assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], len(input_r["input_ids"]))
self.assert_padded_input_match(
input_r["input_ids"], input_p["input_ids"], len(input_r["input_ids"]), pad_token_id
)
self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"])
@@ -2507,7 +2530,7 @@ class TokenizerTesterMixin:
input_p = tokenizer_p.encode_plus(
"This is a simple input", "This is a pair", max_length=max_length, pad_to_max_length=True
)
assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], max_length)
self.assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], max_length, pad_token_id)
self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"])
input_r = tokenizer_r.encode_plus(
"This is a simple input", "This is a pair", max_length=max_length, padding="max_length"
@@ -2515,11 +2538,13 @@ class TokenizerTesterMixin:
input_p = tokenizer_p.encode_plus(
"This is a simple input", "This is a pair", max_length=max_length, padding="max_length"
)
assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], max_length)
self.assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], max_length, pad_token_id)
self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"])
input_r = tokenizer_r.encode_plus("This is a simple input", "This is a pair", padding="longest")
input_p = tokenizer_p.encode_plus("This is a simple input", "This is a pair", padding=True)
assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], len(input_r["input_ids"]))
self.assert_padded_input_match(
input_r["input_ids"], input_p["input_ids"], len(input_r["input_ids"]), pad_token_id
)
self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"])
# Batch_encode_plus - Simple input
@@ -2533,7 +2558,7 @@ class TokenizerTesterMixin:
max_length=max_length,
pad_to_max_length=True,
)
assert_batch_padded_input_match(input_r, input_p, max_length)
self.assert_batch_padded_input_match(input_r, input_p, max_length, pad_token_id)
input_r = tokenizer_r.batch_encode_plus(
["This is a simple input 1", "This is a simple input 2"],
@@ -2545,7 +2570,7 @@ class TokenizerTesterMixin:
max_length=max_length,
padding="max_length",
)
assert_batch_padded_input_match(input_r, input_p, max_length)
self.assert_batch_padded_input_match(input_r, input_p, max_length, pad_token_id)
input_r = tokenizer_r.batch_encode_plus(
["This is a simple input 1", "This is a simple input 2"],
@@ -2557,7 +2582,7 @@ class TokenizerTesterMixin:
max_length=max_length,
padding=True,
)
assert_batch_padded_input_match(input_r, input_p, len(input_r["input_ids"][0]))
self.assert_batch_padded_input_match(input_r, input_p, len(input_r["input_ids"][0]), pad_token_id)
input_r = tokenizer_r.batch_encode_plus(
["This is a simple input 1", "This is a simple input 2"], padding="longest"
@@ -2565,7 +2590,7 @@ class TokenizerTesterMixin:
input_p = tokenizer_p.batch_encode_plus(
["This is a simple input 1", "This is a simple input 2"], padding=True
)
assert_batch_padded_input_match(input_r, input_p, len(input_r["input_ids"][0]))
self.assert_batch_padded_input_match(input_r, input_p, len(input_r["input_ids"][0]), pad_token_id)
# Batch_encode_plus - Pair input
input_r = tokenizer_r.batch_encode_plus(
@@ -2586,7 +2611,7 @@ class TokenizerTesterMixin:
truncation=True,
padding="max_length",
)
assert_batch_padded_input_match(input_r, input_p, max_length)
self.assert_batch_padded_input_match(input_r, input_p, max_length, pad_token_id)
input_r = tokenizer_r.batch_encode_plus(
[
@@ -2602,7 +2627,7 @@ class TokenizerTesterMixin:
],
padding="longest",
)
assert_batch_padded_input_match(input_r, input_p, len(input_r["input_ids"][0]))
self.assert_batch_padded_input_match(input_r, input_p, len(input_r["input_ids"][0]), pad_token_id)
# Using pad on single examples after tokenization
input_r = tokenizer_r.encode_plus("This is a input 1")
@@ -2611,7 +2636,9 @@ class TokenizerTesterMixin:
input_p = tokenizer_r.encode_plus("This is a input 1")
input_p = tokenizer_r.pad(input_p)
assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], len(input_r["input_ids"]))
self.assert_padded_input_match(
input_r["input_ids"], input_p["input_ids"], len(input_r["input_ids"]), pad_token_id
)
# Using pad on single examples after tokenization
input_r = tokenizer_r.encode_plus("This is a input 1")
@@ -2620,7 +2647,7 @@ class TokenizerTesterMixin:
input_p = tokenizer_r.encode_plus("This is a input 1")
input_p = tokenizer_r.pad(input_p, max_length=max_length, padding="max_length")
assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], max_length)
self.assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], max_length, pad_token_id)
# Using pad after tokenization
input_r = tokenizer_r.batch_encode_plus(
@@ -2633,7 +2660,7 @@ class TokenizerTesterMixin:
)
input_p = tokenizer_r.pad(input_p)
assert_batch_padded_input_match(input_r, input_p, len(input_r["input_ids"][0]))
self.assert_batch_padded_input_match(input_r, input_p, len(input_r["input_ids"][0]), pad_token_id)
# Using pad after tokenization
input_r = tokenizer_r.batch_encode_plus(
@@ -2646,7 +2673,41 @@ class TokenizerTesterMixin:
)
input_p = tokenizer_r.pad(input_p, max_length=max_length, padding="max_length")
assert_batch_padded_input_match(input_r, input_p, max_length)
self.assert_batch_padded_input_match(input_r, input_p, max_length, pad_token_id)
def test_padding_different_model_input_name(self):
for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
with self.subTest("{} ({})".format(tokenizer.__class__.__name__, pretrained_name)):
tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs)
tokenizer_p = self.tokenizer_class.from_pretrained(pretrained_name, **kwargs)
self.assertEqual(tokenizer_p.pad_token_id, tokenizer_r.pad_token_id)
pad_token_id = tokenizer_p.pad_token_id
input_r = tokenizer_r.batch_encode_plus(
["This is a input 1", "This is a much longer input whilch should be padded"]
)
input_p = tokenizer_r.batch_encode_plus(
["This is a input 1", "This is a much longer input whilch should be padded"]
)
# rename encoded batch to "inputs"
input_r["inputs"] = input_r[tokenizer_r.model_input_names[0]]
del input_r[tokenizer_r.model_input_names[0]]
input_p["inputs"] = input_p[tokenizer_p.model_input_names[0]]
del input_p[tokenizer_p.model_input_names[0]]
# Renaming `input_ids` to `inputs`
tokenizer_r.model_input_names = ["inputs"] + tokenizer_r.model_input_names[1:]
tokenizer_p.model_input_names = ["inputs"] + tokenizer_p.model_input_names[1:]
input_r = tokenizer_r.pad(input_r, padding="longest")
input_p = tokenizer_r.pad(input_p, padding="longest")
max_length = len(input_p["inputs"][0])
self.assert_batch_padded_input_match(
input_r, input_p, max_length, pad_token_id, model_main_input_name="inputs"
)
def test_save_pretrained(self):
for tokenizer, pretrained_name, kwargs in self.tokenizers_list:

View File

@@ -174,3 +174,7 @@ class GPT2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
max_length=max_length,
padding="max_length",
)
# tokenizer has no padding token
def test_padding_different_model_input_name(self):
pass

View File

@@ -128,3 +128,7 @@ class OpenAIGPTTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
max_length=max_length,
padding="max_length",
)
# tokenizer has no padding token
def test_padding_different_model_input_name(self):
pass

View File

@@ -107,6 +107,10 @@ class ReformerTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
padding="max_length",
)
# tokenizer has no padding token
def test_padding_different_model_input_name(self):
pass
def test_full_tokenizer(self):
tokenizer = ReformerTokenizer(SAMPLE_VOCAB, keep_accents=True)