Add pad_to_multiple_of on tokenizers (reimport) (#5054)
* Add new parameter `pad_to_multiple_of` on tokenizers.
* Add unit test for pad_to_multiple_of.
* Add .name when logging enum.
* Fix missing .items() on dict in tests.
* Add special check + warning if the tokenizer doesn't have a proper pad_token.
* Use the correct logger format specifier.
* Ensure tokenizers with no pad_token do not modify the underlying padding strategy.
* Skip test if tokenizer doesn't have pad_token.
* Fix RobertaTokenizer on empty input.
* Format.

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Fix and update to simpler API.

Co-authored-by: Thomas Wolf <thomwolf@users.noreply.github.com>
This commit is contained in:
@@ -883,6 +883,40 @@ class TokenizerTesterMixin:
|
||||
assert sequence_length == padded_sequence_right_length
|
||||
assert encoded_sequence == padded_sequence_right
|
||||
|
||||
def test_padding_to_multiple_of(self):
|
||||
tokenizers = self.get_tokenizers()
|
||||
for tokenizer in tokenizers:
|
||||
if tokenizer.pad_token is None:
|
||||
self.skipTest("No padding token.")
|
||||
else:
|
||||
with self.subTest(f"{tokenizer.__class__.__name__}"):
|
||||
empty_tokens = tokenizer("", padding=True, pad_to_multiple_of=8)
|
||||
normal_tokens = tokenizer("This is a sample input", padding=True, pad_to_multiple_of=8)
|
||||
for key, value in empty_tokens.items():
|
||||
self.assertEqual(len(value) % 8, 0, "BatchEncoding.{} is not multiple of 8".format(key))
|
||||
for key, value in normal_tokens.items():
|
||||
self.assertEqual(len(value) % 8, 0, "BatchEncoding.{} is not multiple of 8".format(key))
|
||||
|
||||
normal_tokens = tokenizer("This", pad_to_multiple_of=8)
|
||||
for key, value in normal_tokens.items():
|
||||
self.assertNotEqual(len(value) % 8, 0, "BatchEncoding.{} is not multiple of 8".format(key))
|
||||
|
||||
# Should also work with truncation
|
||||
normal_tokens = tokenizer("This", padding=True, truncation=True, pad_to_multiple_of=8)
|
||||
for key, value in normal_tokens.items():
|
||||
self.assertEqual(len(value) % 8, 0, "BatchEncoding.{} is not multiple of 8".format(key))
|
||||
|
||||
# truncation to something which is not a multiple of pad_to_multiple_of raises an error
|
||||
self.assertRaises(
|
||||
ValueError,
|
||||
tokenizer.__call__,
|
||||
"This",
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=12,
|
||||
pad_to_multiple_of=8,
|
||||
)
|
||||
|
||||
def test_encode_plus_with_padding(self):
|
||||
tokenizers = self.get_tokenizers(do_lower_case=False)
|
||||
for tokenizer in tokenizers:
|
||||
|
||||
Reference in New Issue
Block a user