Fix max_length not taken into account when using pad_to_max_length on fast tokenizers (#2961)
* enable_padding should pad up to max_length if set. Signed-off-by: Morgan Funtowicz <morgan@huggingface.co> * Added more testing on padding. Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
This commit is contained in:
@@ -79,7 +79,7 @@ def truncate_and_pad(
|
|||||||
|
|
||||||
if pad_to_max_length and (pad_token and pad_token_id >= 0):
|
if pad_to_max_length and (pad_token and pad_token_id >= 0):
|
||||||
tokenizer.enable_padding(
|
tokenizer.enable_padding(
|
||||||
max_length=None,
|
max_length=max_length,
|
||||||
direction=padding_side,
|
direction=padding_side,
|
||||||
pad_id=pad_token_id,
|
pad_id=pad_token_id,
|
||||||
pad_type_id=pad_token_type_id,
|
pad_type_id=pad_token_type_id,
|
||||||
|
|||||||
@@ -76,6 +76,63 @@ class FastTokenizerMatchingTest(unittest.TestCase):
|
|||||||
for key in filter(lambda x: x in ["input_ids", "token_type_ids", "attention_mask"], input_p.keys()):
|
for key in filter(lambda x: x in ["input_ids", "token_type_ids", "attention_mask"], input_p.keys()):
|
||||||
self.assert_sequence_almost_equals(input_p[key], input_r[key], threshold)
|
self.assert_sequence_almost_equals(input_p[key], input_r[key], threshold)
|
||||||
|
|
||||||
|
def assert_padding(self, tokenizer_r, tokenizer_p):
|
||||||
|
# Simple input
|
||||||
|
input_r = tokenizer_r.encode("This is a simple input", max_length=15, pad_to_max_length=True)
|
||||||
|
input_p = tokenizer_p.encode("This is a simple input", max_length=15, pad_to_max_length=True)
|
||||||
|
|
||||||
|
self.assertSequenceEqual(input_r, input_p)
|
||||||
|
|
||||||
|
# Simple input
|
||||||
|
input_r = tokenizer_r.encode_plus("This is a simple input", max_length=15, pad_to_max_length=True)
|
||||||
|
input_p = tokenizer_p.encode_plus("This is a simple input", max_length=15, pad_to_max_length=True)
|
||||||
|
|
||||||
|
self.assertSequenceEqual(input_r, input_p)
|
||||||
|
|
||||||
|
# Simple input
|
||||||
|
# TODO: Re-enable this test when batch_encode_plus with padding correctly handles padding
|
||||||
|
# input_r = tokenizer_r.batch_encode_plus(
|
||||||
|
# ["This is a simple input 1", "This is a simple input 2"], max_length=15, pad_to_max_length=True
|
||||||
|
# )
|
||||||
|
# input_p = tokenizer_p.batch_encode_plus(
|
||||||
|
# ["This is a simple input 1", "This is a simple input 2"], max_length=15, pad_to_max_length=True
|
||||||
|
# )
|
||||||
|
|
||||||
|
# self.assertSequenceEqual(input_r, input_p)
|
||||||
|
|
||||||
|
# Pair input
|
||||||
|
input_r = tokenizer_r.encode("This is a simple input", "This is a pair", max_length=15, pad_to_max_length=True)
|
||||||
|
input_p = tokenizer_p.encode("This is a simple input", "This is a pair", max_length=15, pad_to_max_length=True)
|
||||||
|
|
||||||
|
self.assertSequenceEqual(input_r, input_p)
|
||||||
|
|
||||||
|
# Pair input
|
||||||
|
input_r = tokenizer_r.encode_plus(
|
||||||
|
"This is a simple input", "This is a pair", max_length=15, pad_to_max_length=True
|
||||||
|
)
|
||||||
|
input_p = tokenizer_p.encode_plus(
|
||||||
|
"This is a simple input", "This is a pair", max_length=15, pad_to_max_length=True
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertSequenceEqual(input_r, input_p)
|
||||||
|
|
||||||
|
# Pair input
|
||||||
|
# TODO: Re-enable this test when batch_encode_plus with padding correctly handles padding
|
||||||
|
# input_r = tokenizer_r.batch_encode_plus(
|
||||||
|
# ["This is a simple input 1", "This is a simple input 2"],
|
||||||
|
# ["This is a simple pair 1", "This is a simple pair 2"],
|
||||||
|
# max_length=15,
|
||||||
|
# pad_to_max_length=True,
|
||||||
|
# )
|
||||||
|
# input_p = tokenizer_p.batch_encode_plus(
|
||||||
|
# ["This is a simple input 1", "This is a simple input 2"],
|
||||||
|
# ["This is a simple pair 1", "This is a simple pair 2"],
|
||||||
|
# max_length=15,
|
||||||
|
# pad_to_max_length=True,
|
||||||
|
# )
|
||||||
|
|
||||||
|
# self.assertSequenceEqual(input_r, input_p)
|
||||||
|
|
||||||
def assert_add_tokens(self, tokenizer_r):
|
def assert_add_tokens(self, tokenizer_r):
|
||||||
vocab_size = tokenizer_r.vocab_size
|
vocab_size = tokenizer_r.vocab_size
|
||||||
self.assertEqual(tokenizer_r.add_tokens(""), 0)
|
self.assertEqual(tokenizer_r.add_tokens(""), 0)
|
||||||
@@ -239,6 +296,9 @@ class FastTokenizerMatchingTest(unittest.TestCase):
|
|||||||
# Check the number of returned files for save_vocabulary
|
# Check the number of returned files for save_vocabulary
|
||||||
self.assertEqual(len(tokenizer_r.save_vocabulary(".")), len(tokenizer_p.save_vocabulary(".")))
|
self.assertEqual(len(tokenizer_r.save_vocabulary(".")), len(tokenizer_p.save_vocabulary(".")))
|
||||||
|
|
||||||
|
# Check for padding
|
||||||
|
self.assert_padding(tokenizer_r, tokenizer_p)
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
def test_transfoxl(self):
|
def test_transfoxl(self):
|
||||||
for tokenizer_name in TransfoXLTokenizer.pretrained_vocab_files_map["pretrained_vocab_file"].keys():
|
for tokenizer_name in TransfoXLTokenizer.pretrained_vocab_files_map["pretrained_vocab_file"].keys():
|
||||||
@@ -278,6 +338,9 @@ class FastTokenizerMatchingTest(unittest.TestCase):
|
|||||||
# Check the number of returned files for save_vocabulary
|
# Check the number of returned files for save_vocabulary
|
||||||
self.assertEqual(len(tokenizer_r.save_vocabulary(".")), len(tokenizer_p.save_vocabulary(".")))
|
self.assertEqual(len(tokenizer_r.save_vocabulary(".")), len(tokenizer_p.save_vocabulary(".")))
|
||||||
|
|
||||||
|
# Check for padding
|
||||||
|
self.assertRaises(ValueError, self.assert_padding, tokenizer_r, tokenizer_p)
|
||||||
|
|
||||||
def test_distilbert(self):
|
def test_distilbert(self):
|
||||||
for tokenizer_name in DistilBertTokenizer.pretrained_vocab_files_map["vocab_file"].keys():
|
for tokenizer_name in DistilBertTokenizer.pretrained_vocab_files_map["vocab_file"].keys():
|
||||||
tokenizer_p = DistilBertTokenizer.from_pretrained(tokenizer_name)
|
tokenizer_p = DistilBertTokenizer.from_pretrained(tokenizer_name)
|
||||||
@@ -317,6 +380,9 @@ class FastTokenizerMatchingTest(unittest.TestCase):
|
|||||||
# Check the number of returned files for save_vocabulary
|
# Check the number of returned files for save_vocabulary
|
||||||
self.assertEqual(len(tokenizer_r.save_vocabulary(".")), len(tokenizer_p.save_vocabulary(".")))
|
self.assertEqual(len(tokenizer_r.save_vocabulary(".")), len(tokenizer_p.save_vocabulary(".")))
|
||||||
|
|
||||||
|
# Check for padding
|
||||||
|
self.assert_padding(tokenizer_r, tokenizer_p)
|
||||||
|
|
||||||
def test_gpt2(self):
|
def test_gpt2(self):
|
||||||
for tokenizer_name in GPT2Tokenizer.pretrained_vocab_files_map["vocab_file"].keys():
|
for tokenizer_name in GPT2Tokenizer.pretrained_vocab_files_map["vocab_file"].keys():
|
||||||
tokenizer_p = GPT2Tokenizer.from_pretrained(tokenizer_name)
|
tokenizer_p = GPT2Tokenizer.from_pretrained(tokenizer_name)
|
||||||
@@ -355,6 +421,9 @@ class FastTokenizerMatchingTest(unittest.TestCase):
|
|||||||
# Check the number of returned files for save_vocabulary
|
# Check the number of returned files for save_vocabulary
|
||||||
self.assertEqual(len(tokenizer_r.save_vocabulary(".")), len(tokenizer_p.save_vocabulary(".")))
|
self.assertEqual(len(tokenizer_r.save_vocabulary(".")), len(tokenizer_p.save_vocabulary(".")))
|
||||||
|
|
||||||
|
# Check for padding
|
||||||
|
self.assertRaises(ValueError, self.assert_padding, tokenizer_r, tokenizer_p)
|
||||||
|
|
||||||
def test_roberta(self):
|
def test_roberta(self):
|
||||||
for tokenizer_name in RobertaTokenizer.pretrained_vocab_files_map["vocab_file"].keys():
|
for tokenizer_name in RobertaTokenizer.pretrained_vocab_files_map["vocab_file"].keys():
|
||||||
tokenizer_p = RobertaTokenizer.from_pretrained(tokenizer_name)
|
tokenizer_p = RobertaTokenizer.from_pretrained(tokenizer_name)
|
||||||
@@ -393,6 +462,10 @@ class FastTokenizerMatchingTest(unittest.TestCase):
|
|||||||
# Check the number of returned files for save_vocabulary
|
# Check the number of returned files for save_vocabulary
|
||||||
self.assertEqual(len(tokenizer_r.save_vocabulary(".")), len(tokenizer_p.save_vocabulary(".")))
|
self.assertEqual(len(tokenizer_r.save_vocabulary(".")), len(tokenizer_p.save_vocabulary(".")))
|
||||||
|
|
||||||
|
# Check for padding
|
||||||
|
# TODO: Re-enable this test as soon as Roberta align with the python tokenizer.
|
||||||
|
# self.assert_padding(tokenizer_r, tokenizer_p)
|
||||||
|
|
||||||
def test_openai(self):
|
def test_openai(self):
|
||||||
for tokenizer_name in OpenAIGPTTokenizer.pretrained_vocab_files_map["vocab_file"].keys():
|
for tokenizer_name in OpenAIGPTTokenizer.pretrained_vocab_files_map["vocab_file"].keys():
|
||||||
tokenizer_p = OpenAIGPTTokenizer.from_pretrained(tokenizer_name)
|
tokenizer_p = OpenAIGPTTokenizer.from_pretrained(tokenizer_name)
|
||||||
@@ -431,6 +504,9 @@ class FastTokenizerMatchingTest(unittest.TestCase):
|
|||||||
# Check the number of returned files for save_vocabulary
|
# Check the number of returned files for save_vocabulary
|
||||||
self.assertEqual(len(tokenizer_r.save_vocabulary(".")), len(tokenizer_p.save_vocabulary(".")))
|
self.assertEqual(len(tokenizer_r.save_vocabulary(".")), len(tokenizer_p.save_vocabulary(".")))
|
||||||
|
|
||||||
|
# Check for padding
|
||||||
|
self.assertRaises(ValueError, self.assert_padding, tokenizer_r, tokenizer_p)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
Reference in New Issue
Block a user