Fix max_length not taken into account when using pad_to_max_length on fast tokenizers (#2961)
* enable_padding should pad up to max_length if set. Signed-off-by: Morgan Funtowicz <morgan@huggingface.co> * Added more testing on padding. Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
This commit is contained in:
@@ -76,6 +76,63 @@ class FastTokenizerMatchingTest(unittest.TestCase):
|
||||
for key in filter(lambda x: x in ["input_ids", "token_type_ids", "attention_mask"], input_p.keys()):
|
||||
self.assert_sequence_almost_equals(input_p[key], input_r[key], threshold)
|
||||
|
||||
def assert_padding(self, tokenizer_r, tokenizer_p):
|
||||
# Simple input
|
||||
input_r = tokenizer_r.encode("This is a simple input", max_length=15, pad_to_max_length=True)
|
||||
input_p = tokenizer_p.encode("This is a simple input", max_length=15, pad_to_max_length=True)
|
||||
|
||||
self.assertSequenceEqual(input_r, input_p)
|
||||
|
||||
# Simple input
|
||||
input_r = tokenizer_r.encode_plus("This is a simple input", max_length=15, pad_to_max_length=True)
|
||||
input_p = tokenizer_p.encode_plus("This is a simple input", max_length=15, pad_to_max_length=True)
|
||||
|
||||
self.assertSequenceEqual(input_r, input_p)
|
||||
|
||||
# Simple input
|
||||
# TODO: Re-enable this test when batch_encode_plus with padding correctly handles padding
|
||||
# input_r = tokenizer_r.batch_encode_plus(
|
||||
# ["This is a simple input 1", "This is a simple input 2"], max_length=15, pad_to_max_length=True
|
||||
# )
|
||||
# input_p = tokenizer_p.batch_encode_plus(
|
||||
# ["This is a simple input 1", "This is a simple input 2"], max_length=15, pad_to_max_length=True
|
||||
# )
|
||||
|
||||
# self.assertSequenceEqual(input_r, input_p)
|
||||
|
||||
# Pair input
|
||||
input_r = tokenizer_r.encode("This is a simple input", "This is a pair", max_length=15, pad_to_max_length=True)
|
||||
input_p = tokenizer_p.encode("This is a simple input", "This is a pair", max_length=15, pad_to_max_length=True)
|
||||
|
||||
self.assertSequenceEqual(input_r, input_p)
|
||||
|
||||
# Pair input
|
||||
input_r = tokenizer_r.encode_plus(
|
||||
"This is a simple input", "This is a pair", max_length=15, pad_to_max_length=True
|
||||
)
|
||||
input_p = tokenizer_p.encode_plus(
|
||||
"This is a simple input", "This is a pair", max_length=15, pad_to_max_length=True
|
||||
)
|
||||
|
||||
self.assertSequenceEqual(input_r, input_p)
|
||||
|
||||
# Pair input
|
||||
# TODO: Re-enable this test when batch_encode_plus with padding correctly handles padding
|
||||
# input_r = tokenizer_r.batch_encode_plus(
|
||||
# ["This is a simple input 1", "This is a simple input 2"],
|
||||
# ["This is a simple pair 1", "This is a simple pair 2"],
|
||||
# max_length=15,
|
||||
# pad_to_max_length=True,
|
||||
# )
|
||||
# input_p = tokenizer_p.batch_encode_plus(
|
||||
# ["This is a simple input 1", "This is a simple input 2"],
|
||||
# ["This is a simple pair 1", "This is a simple pair 2"],
|
||||
# max_length=15,
|
||||
# pad_to_max_length=True,
|
||||
# )
|
||||
|
||||
# self.assertSequenceEqual(input_r, input_p)
|
||||
|
||||
def assert_add_tokens(self, tokenizer_r):
|
||||
vocab_size = tokenizer_r.vocab_size
|
||||
self.assertEqual(tokenizer_r.add_tokens(""), 0)
|
||||
@@ -239,6 +296,9 @@ class FastTokenizerMatchingTest(unittest.TestCase):
|
||||
# Check the number of returned files for save_vocabulary
|
||||
self.assertEqual(len(tokenizer_r.save_vocabulary(".")), len(tokenizer_p.save_vocabulary(".")))
|
||||
|
||||
# Check for padding
|
||||
self.assert_padding(tokenizer_r, tokenizer_p)
|
||||
|
||||
@require_torch
|
||||
def test_transfoxl(self):
|
||||
for tokenizer_name in TransfoXLTokenizer.pretrained_vocab_files_map["pretrained_vocab_file"].keys():
|
||||
@@ -278,6 +338,9 @@ class FastTokenizerMatchingTest(unittest.TestCase):
|
||||
# Check the number of returned files for save_vocabulary
|
||||
self.assertEqual(len(tokenizer_r.save_vocabulary(".")), len(tokenizer_p.save_vocabulary(".")))
|
||||
|
||||
# Check for padding
|
||||
self.assertRaises(ValueError, self.assert_padding, tokenizer_r, tokenizer_p)
|
||||
|
||||
def test_distilbert(self):
|
||||
for tokenizer_name in DistilBertTokenizer.pretrained_vocab_files_map["vocab_file"].keys():
|
||||
tokenizer_p = DistilBertTokenizer.from_pretrained(tokenizer_name)
|
||||
@@ -317,6 +380,9 @@ class FastTokenizerMatchingTest(unittest.TestCase):
|
||||
# Check the number of returned files for save_vocabulary
|
||||
self.assertEqual(len(tokenizer_r.save_vocabulary(".")), len(tokenizer_p.save_vocabulary(".")))
|
||||
|
||||
# Check for padding
|
||||
self.assert_padding(tokenizer_r, tokenizer_p)
|
||||
|
||||
def test_gpt2(self):
|
||||
for tokenizer_name in GPT2Tokenizer.pretrained_vocab_files_map["vocab_file"].keys():
|
||||
tokenizer_p = GPT2Tokenizer.from_pretrained(tokenizer_name)
|
||||
@@ -355,6 +421,9 @@ class FastTokenizerMatchingTest(unittest.TestCase):
|
||||
# Check the number of returned files for save_vocabulary
|
||||
self.assertEqual(len(tokenizer_r.save_vocabulary(".")), len(tokenizer_p.save_vocabulary(".")))
|
||||
|
||||
# Check for padding
|
||||
self.assertRaises(ValueError, self.assert_padding, tokenizer_r, tokenizer_p)
|
||||
|
||||
def test_roberta(self):
|
||||
for tokenizer_name in RobertaTokenizer.pretrained_vocab_files_map["vocab_file"].keys():
|
||||
tokenizer_p = RobertaTokenizer.from_pretrained(tokenizer_name)
|
||||
@@ -393,6 +462,10 @@ class FastTokenizerMatchingTest(unittest.TestCase):
|
||||
# Check the number of returned files for save_vocabulary
|
||||
self.assertEqual(len(tokenizer_r.save_vocabulary(".")), len(tokenizer_p.save_vocabulary(".")))
|
||||
|
||||
# Check for padding
|
||||
# TODO: Re-enable this test as soon as Roberta align with the python tokenizer.
|
||||
# self.assert_padding(tokenizer_r, tokenizer_p)
|
||||
|
||||
def test_openai(self):
|
||||
for tokenizer_name in OpenAIGPTTokenizer.pretrained_vocab_files_map["vocab_file"].keys():
|
||||
tokenizer_p = OpenAIGPTTokenizer.from_pretrained(tokenizer_name)
|
||||
@@ -431,6 +504,9 @@ class FastTokenizerMatchingTest(unittest.TestCase):
|
||||
# Check the number of returned files for save_vocabulary
|
||||
self.assertEqual(len(tokenizer_r.save_vocabulary(".")), len(tokenizer_p.save_vocabulary(".")))
|
||||
|
||||
# Check for padding
|
||||
self.assertRaises(ValueError, self.assert_padding, tokenizer_r, tokenizer_p)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user