Add pad_to_multiple_of on tokenizers (reimport) (#5054)

* Add new parameter `pad_to_multiple_of` on tokenizers.

* unittest for pad_to_multiple_of

* Add .name when logging enum.

* Fix missing .items() on dict in tests.

* Add special check + warning if the tokenizer doesn't have proper pad_token.

* Use the correct logger format specifier.

* Ensure tokenizer with no pad_token do not modify the underlying padding strategy.

* Skip test if tokenizer doesn't have pad_token

* Fix RobertaTokenizer on empty input

* Format.

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* fix and updating to simpler API

Co-authored-by: Thomas Wolf <thomwolf@users.noreply.github.com>
This commit is contained in:
Funtowicz Morgan
2020-06-26 11:55:57 +02:00
committed by GitHub
parent 7cc15bdd96
commit 135791e8ef
5 changed files with 117 additions and 18 deletions

View File

@@ -883,6 +883,40 @@ class TokenizerTesterMixin:
assert sequence_length == padded_sequence_right_length
assert encoded_sequence == padded_sequence_right
def test_padding_to_multiple_of(self):
tokenizers = self.get_tokenizers()
for tokenizer in tokenizers:
if tokenizer.pad_token is None:
self.skipTest("No padding token.")
else:
with self.subTest(f"{tokenizer.__class__.__name__}"):
empty_tokens = tokenizer("", padding=True, pad_to_multiple_of=8)
normal_tokens = tokenizer("This is a sample input", padding=True, pad_to_multiple_of=8)
for key, value in empty_tokens.items():
self.assertEqual(len(value) % 8, 0, "BatchEncoding.{} is not multiple of 8".format(key))
for key, value in normal_tokens.items():
self.assertEqual(len(value) % 8, 0, "BatchEncoding.{} is not multiple of 8".format(key))
normal_tokens = tokenizer("This", pad_to_multiple_of=8)
for key, value in normal_tokens.items():
self.assertNotEqual(len(value) % 8, 0, "BatchEncoding.{} is not multiple of 8".format(key))
# Should also work with truncation
normal_tokens = tokenizer("This", padding=True, truncation=True, pad_to_multiple_of=8)
for key, value in normal_tokens.items():
self.assertEqual(len(value) % 8, 0, "BatchEncoding.{} is not multiple of 8".format(key))
# truncation to something which is not a multiple of pad_to_multiple_of raises an error
self.assertRaises(
ValueError,
tokenizer.__call__,
"This",
padding=True,
truncation=True,
max_length=12,
pad_to_multiple_of=8,
)
def test_encode_plus_with_padding(self):
tokenizers = self.get_tokenizers(do_lower_case=False)
for tokenizer in tokenizers: