From e67676424191e5935362e5fe7e04b5c317d706a9 Mon Sep 17 00:00:00 2001 From: Funtowicz Morgan Date: Wed, 19 Feb 2020 22:09:51 +0100 Subject: [PATCH] Override build_inputs_with_special_tokens for fast tokenizers (#2912) * Override build_inputs_with_special_tokens for fast impl + unittest. Signed-off-by: Morgan Funtowicz * Quality + format. Signed-off-by: Morgan Funtowicz --- src/transformers/tokenization_bert.py | 8 ++++ src/transformers/tokenization_roberta.py | 7 ++++ src/transformers/tokenization_utils.py | 6 +++ tests/test_tokenization_fast.py | 47 ++++++++++++++++++++++++ 4 files changed, 68 insertions(+) diff --git a/src/transformers/tokenization_bert.py b/src/transformers/tokenization_bert.py index 63b439f952..834a610bce 100644 --- a/src/transformers/tokenization_bert.py +++ b/src/transformers/tokenization_bert.py @@ -572,3 +572,11 @@ class BertTokenizerFast(PreTrainedTokenizerFast): ) self.do_lower_case = do_lower_case + + def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): + output = [self.cls_token_id] + token_ids_0 + [self.sep_token_id] + + if token_ids_1: + output += token_ids_1 + [self.sep_token_id] + + return output diff --git a/src/transformers/tokenization_roberta.py b/src/transformers/tokenization_roberta.py index 4f470c5dc4..ff2aa11004 100644 --- a/src/transformers/tokenization_roberta.py +++ b/src/transformers/tokenization_roberta.py @@ -210,3 +210,10 @@ class RobertaTokenizerFast(GPT2TokenizerFast): # We need to recompute max_len according to the newly register post_processor to get real values. self.max_len_single_sentence = self.max_len - self.num_added_tokens(False) # take into account special tokens self.max_len_sentences_pair = self.max_len - self.num_added_tokens(True) # take into account special tokens + + def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): + output = [self.bos_token_id] + token_ids_0 + [self.eos_token_id] + if token_ids_1 is None: + return output + + return output + [self.eos_token_id] + token_ids_1 + [self.eos_token_id] diff --git a/src/transformers/tokenization_utils.py b/src/transformers/tokenization_utils.py index e85bd67aa3..761b1d29bb 100644 --- a/src/transformers/tokenization_utils.py +++ b/src/transformers/tokenization_utils.py @@ -1669,6 +1669,12 @@ class PreTrainedTokenizerFast(PreTrainedTokenizer): self._update_special_tokens() return added + def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): + if token_ids_1 is None: + return token_ids_0 + else: + return token_ids_0 + token_ids_1 + def num_added_tokens(self, pair=False): return self.tokenizer.num_special_tokens_to_add(pair) diff --git a/tests/test_tokenization_fast.py b/tests/test_tokenization_fast.py index fe4d89ba5d..fdcebd0117 100644 --- a/tests/test_tokenization_fast.py +++ b/tests/test_tokenization_fast.py @@ -172,6 +172,35 @@ class FastTokenizerMatchingTest(unittest.TestCase): self.assertEqual(len(tokens[key].shape), 2) self.assertEqual(tokens[key].shape[-1], 6) + def assert_build_inputs_with_special_tokens(self, tokenizer_r, tokenizer_p): + # Input string + input_simple = tokenizer_p.tokenize("This is a sample input") + input_pair = tokenizer_p.tokenize("This is a sample pair") + + # Generate output + output_r = tokenizer_r.build_inputs_with_special_tokens(input_simple) + output_p = tokenizer_p.build_inputs_with_special_tokens(input_simple) + self.assertEqual(output_p, output_r) + + # Generate pair output + output_r = tokenizer_r.build_inputs_with_special_tokens(input_simple, input_pair) + output_p = tokenizer_p.build_inputs_with_special_tokens(input_simple, input_pair) + self.assertEqual(output_p, output_r) + + # Input tokens id + input_simple = tokenizer_p.encode("This is a sample input") + input_pair = tokenizer_p.encode("This is a sample pair") + + # Generate output + output_r = tokenizer_r.build_inputs_with_special_tokens(input_simple) + output_p = tokenizer_p.build_inputs_with_special_tokens(input_simple) + self.assertEqual(output_p, output_r) + + # Generate pair output + output_r = tokenizer_r.build_inputs_with_special_tokens(input_simple, input_pair) + output_p = tokenizer_p.build_inputs_with_special_tokens(input_simple, input_pair) + self.assertEqual(output_p, output_r) + def test_bert(self): for tokenizer_name in BertTokenizer.pretrained_vocab_files_map["vocab_file"].keys(): tokenizer_p = BertTokenizer.from_pretrained(tokenizer_name) @@ -204,6 +233,9 @@ class FastTokenizerMatchingTest(unittest.TestCase): # Check for dynamic encoding sequence handling in batch_encode_plus self.assert_batch_encode_dynamic_overflowing(tokenizer_r) + # Check alignment for build_inputs_with_special_tokens + self.assert_build_inputs_with_special_tokens(tokenizer_r, tokenizer_p) + @require_torch def test_transfoxl(self): for tokenizer_name in TransfoXLTokenizer.pretrained_vocab_files_map["pretrained_vocab_file"].keys(): @@ -237,6 +269,9 @@ class FastTokenizerMatchingTest(unittest.TestCase): # Check for dynamic encoding sequence handling in batch_encode_plus self.assertRaises(ValueError, self.assert_batch_encode_dynamic_overflowing, tokenizer_r) + # Check alignment for build_inputs_with_special_tokens + self.assert_build_inputs_with_special_tokens(tokenizer_r, tokenizer_p) + def test_distilbert(self): for tokenizer_name in DistilBertTokenizer.pretrained_vocab_files_map["vocab_file"].keys(): tokenizer_p = DistilBertTokenizer.from_pretrained(tokenizer_name) @@ -270,6 +305,9 @@ class FastTokenizerMatchingTest(unittest.TestCase): # Check for dynamic encoding sequence handling in batch_encode_plus self.assert_batch_encode_dynamic_overflowing(tokenizer_r) + # Check alignment for build_inputs_with_special_tokens + self.assert_build_inputs_with_special_tokens(tokenizer_r, tokenizer_p) + def test_gpt2(self): for tokenizer_name in GPT2Tokenizer.pretrained_vocab_files_map["vocab_file"].keys(): tokenizer_p = GPT2Tokenizer.from_pretrained(tokenizer_name) @@ -302,6 +340,9 @@ class FastTokenizerMatchingTest(unittest.TestCase): # Check for dynamic encoding sequence handling in batch_encode_plus self.assertRaises(ValueError, self.assert_batch_encode_dynamic_overflowing, tokenizer_r) + # Check alignment for build_inputs_with_special_tokens + self.assert_build_inputs_with_special_tokens(tokenizer_r, tokenizer_p) + def test_roberta(self): for tokenizer_name in RobertaTokenizer.pretrained_vocab_files_map["vocab_file"].keys(): tokenizer_p = RobertaTokenizer.from_pretrained(tokenizer_name) @@ -334,6 +375,9 @@ class FastTokenizerMatchingTest(unittest.TestCase): # Check for dynamic encoding sequence handling in batch_encode_plus self.assert_batch_encode_dynamic_overflowing(tokenizer_r) + # Check alignment for build_inputs_with_special_tokens + self.assert_build_inputs_with_special_tokens(tokenizer_r, tokenizer_p) + def test_openai(self): for tokenizer_name in OpenAIGPTTokenizer.pretrained_vocab_files_map["vocab_file"].keys(): tokenizer_p = OpenAIGPTTokenizer.from_pretrained(tokenizer_name) @@ -366,6 +410,9 @@ class FastTokenizerMatchingTest(unittest.TestCase): # Check for dynamic encoding sequence handling in batch_encode_plus self.assertRaises(ValueError, self.assert_batch_encode_dynamic_overflowing, tokenizer_r) + # Check alignment for build_inputs_with_special_tokens + self.assert_build_inputs_with_special_tokens(tokenizer_r, tokenizer_p) + if __name__ == "__main__": unittest.main()