diff --git a/src/transformers/tokenization_utils.py b/src/transformers/tokenization_utils.py index 80ab188055..879d1614b3 100644 --- a/src/transformers/tokenization_utils.py +++ b/src/transformers/tokenization_utils.py @@ -79,7 +79,7 @@ def truncate_and_pad( if pad_to_max_length and (pad_token and pad_token_id >= 0): tokenizer.enable_padding( - max_length=None, + max_length=max_length, direction=padding_side, pad_id=pad_token_id, pad_type_id=pad_token_type_id, diff --git a/tests/test_tokenization_fast.py b/tests/test_tokenization_fast.py index 18a7eac00a..31cd850a61 100644 --- a/tests/test_tokenization_fast.py +++ b/tests/test_tokenization_fast.py @@ -76,6 +76,63 @@ class FastTokenizerMatchingTest(unittest.TestCase): for key in filter(lambda x: x in ["input_ids", "token_type_ids", "attention_mask"], input_p.keys()): self.assert_sequence_almost_equals(input_p[key], input_r[key], threshold) + def assert_padding(self, tokenizer_r, tokenizer_p): + # Simple input + input_r = tokenizer_r.encode("This is a simple input", max_length=15, pad_to_max_length=True) + input_p = tokenizer_p.encode("This is a simple input", max_length=15, pad_to_max_length=True) + + self.assertSequenceEqual(input_r, input_p) + + # Simple input + input_r = tokenizer_r.encode_plus("This is a simple input", max_length=15, pad_to_max_length=True) + input_p = tokenizer_p.encode_plus("This is a simple input", max_length=15, pad_to_max_length=True) + + self.assertSequenceEqual(input_r, input_p) + + # Simple input + # TODO: Re-enable this test when batch_encode_plus with padding correctly handles padding + # input_r = tokenizer_r.batch_encode_plus( + # ["This is a simple input 1", "This is a simple input 2"], max_length=15, pad_to_max_length=True + # ) + # input_p = tokenizer_p.batch_encode_plus( + # ["This is a simple input 1", "This is a simple input 2"], max_length=15, pad_to_max_length=True + # ) + + # self.assertSequenceEqual(input_r, input_p) + + # Pair input + input_r = tokenizer_r.encode("This is a simple input", "This is a pair", max_length=15, pad_to_max_length=True) + input_p = tokenizer_p.encode("This is a simple input", "This is a pair", max_length=15, pad_to_max_length=True) + + self.assertSequenceEqual(input_r, input_p) + + # Pair input + input_r = tokenizer_r.encode_plus( + "This is a simple input", "This is a pair", max_length=15, pad_to_max_length=True + ) + input_p = tokenizer_p.encode_plus( + "This is a simple input", "This is a pair", max_length=15, pad_to_max_length=True + ) + + self.assertSequenceEqual(input_r, input_p) + + # Pair input + # TODO: Re-enable this test when batch_encode_plus with padding correctly handles padding + # input_r = tokenizer_r.batch_encode_plus( + # ["This is a simple input 1", "This is a simple input 2"], + # ["This is a simple pair 1", "This is a simple pair 2"], + # max_length=15, + # pad_to_max_length=True, + # ) + # input_p = tokenizer_p.batch_encode_plus( + # ["This is a simple input 1", "This is a simple input 2"], + # ["This is a simple pair 1", "This is a simple pair 2"], + # max_length=15, + # pad_to_max_length=True, + # ) + + # self.assertSequenceEqual(input_r, input_p) + def assert_add_tokens(self, tokenizer_r): vocab_size = tokenizer_r.vocab_size self.assertEqual(tokenizer_r.add_tokens(""), 0) @@ -239,6 +296,9 @@ class FastTokenizerMatchingTest(unittest.TestCase): # Check the number of returned files for save_vocabulary self.assertEqual(len(tokenizer_r.save_vocabulary(".")), len(tokenizer_p.save_vocabulary("."))) + # Check for padding + self.assert_padding(tokenizer_r, tokenizer_p) + @require_torch def test_transfoxl(self): for tokenizer_name in TransfoXLTokenizer.pretrained_vocab_files_map["pretrained_vocab_file"].keys(): @@ -278,6 +338,9 @@ class FastTokenizerMatchingTest(unittest.TestCase): # Check the number of returned files for save_vocabulary self.assertEqual(len(tokenizer_r.save_vocabulary(".")), len(tokenizer_p.save_vocabulary("."))) + # Check for padding + self.assertRaises(ValueError, self.assert_padding, tokenizer_r, tokenizer_p) + def test_distilbert(self): for tokenizer_name in DistilBertTokenizer.pretrained_vocab_files_map["vocab_file"].keys(): tokenizer_p = DistilBertTokenizer.from_pretrained(tokenizer_name) @@ -317,6 +380,9 @@ class FastTokenizerMatchingTest(unittest.TestCase): # Check the number of returned files for save_vocabulary self.assertEqual(len(tokenizer_r.save_vocabulary(".")), len(tokenizer_p.save_vocabulary("."))) + # Check for padding + self.assert_padding(tokenizer_r, tokenizer_p) + def test_gpt2(self): for tokenizer_name in GPT2Tokenizer.pretrained_vocab_files_map["vocab_file"].keys(): tokenizer_p = GPT2Tokenizer.from_pretrained(tokenizer_name) @@ -355,6 +421,9 @@ class FastTokenizerMatchingTest(unittest.TestCase): # Check the number of returned files for save_vocabulary self.assertEqual(len(tokenizer_r.save_vocabulary(".")), len(tokenizer_p.save_vocabulary("."))) + # Check for padding + self.assertRaises(ValueError, self.assert_padding, tokenizer_r, tokenizer_p) + def test_roberta(self): for tokenizer_name in RobertaTokenizer.pretrained_vocab_files_map["vocab_file"].keys(): tokenizer_p = RobertaTokenizer.from_pretrained(tokenizer_name) @@ -393,6 +462,10 @@ class FastTokenizerMatchingTest(unittest.TestCase): # Check the number of returned files for save_vocabulary self.assertEqual(len(tokenizer_r.save_vocabulary(".")), len(tokenizer_p.save_vocabulary("."))) + # Check for padding + # TODO: Re-enable this test as soon as Roberta align with the python tokenizer. + # self.assert_padding(tokenizer_r, tokenizer_p) + def test_openai(self): for tokenizer_name in OpenAIGPTTokenizer.pretrained_vocab_files_map["vocab_file"].keys(): tokenizer_p = OpenAIGPTTokenizer.from_pretrained(tokenizer_name) @@ -431,6 +504,9 @@ class FastTokenizerMatchingTest(unittest.TestCase): # Check the number of returned files for save_vocabulary self.assertEqual(len(tokenizer_r.save_vocabulary(".")), len(tokenizer_p.save_vocabulary("."))) + # Check for padding + self.assertRaises(ValueError, self.assert_padding, tokenizer_r, tokenizer_p) + if __name__ == "__main__": unittest.main()