From d572d7027b098072e090248f6c13d3cde40c926b Mon Sep 17 00:00:00 2001 From: LysandreJik Date: Wed, 11 Sep 2019 11:20:07 +0200 Subject: [PATCH] Number of added tokens calculator --- .../tests/tokenization_tests_commons.py | 13 ++++++++ pytorch_transformers/tokenization_utils.py | 30 +++++++++++++++++++ 2 files changed, 43 insertions(+) diff --git a/pytorch_transformers/tests/tokenization_tests_commons.py b/pytorch_transformers/tests/tokenization_tests_commons.py index aa6746a758..7500741cee 100644 --- a/pytorch_transformers/tests/tokenization_tests_commons.py +++ b/pytorch_transformers/tests/tokenization_tests_commons.py @@ -198,3 +198,16 @@ class CommonTestCases: seq_1 = "With these inputs." sequences, mask = tokenizer.encode(seq_0, seq_1, add_special_tokens=True, output_mask=True) assert len(sequences) == len(mask) + + def test_number_of_added_tokens(self): + tokenizer = self.get_tokenizer() + + seq_0 = "Test this method." + seq_1 = "With these inputs." + + sequences = tokenizer.encode(seq_0, seq_1) + attached_sequences = tokenizer.encode(seq_0, seq_1, add_special_tokens=True) + + # Skip tokenizers without a sequence pair encoding (e.g. GPT-2), detected here by an encoded result of length 2 + if len(attached_sequences) != 2: + assert tokenizer.num_added_tokens(pair=True) == len(attached_sequences) - sum([len(seq) for seq in sequences]) diff --git a/pytorch_transformers/tokenization_utils.py b/pytorch_transformers/tokenization_utils.py index 6a4f8a2dc9..a22f15fa3e 100644 --- a/pytorch_transformers/tokenization_utils.py +++ b/pytorch_transformers/tokenization_utils.py @@ -518,6 +518,36 @@ class PreTrainedTokenizer(object): return len(to_add_tokens) + def num_added_tokens(self, pair=False): + """ + Returns the number of added tokens when encoding a sequence with special tokens. + + Note: + This encodes inputs and checks the number of added tokens, and is therefore not efficient. Do not put this + inside your training loop. With pair=True, 0 is returned when the tokenizer has no sequence pair encoding.
+ + Args: + pair: Returns the number of added tokens in the case of a sequence pair if set to True, returns the + number of added tokens in the case of a single sequence if set to False. + + Returns: + Number of tokens added to sequences + """ + + if pair: + initial_tokens_len = sum([len(encoded) for encoded in self.encode("This is a sequence", "This is another")]) + final_tokens = self.encode("This is a sequence", "This is another", add_special_tokens=True) + + # In some models (e.g. GPT-2), there is no sequence pair encoding; a result of length 2 is treated as that case, so no tokens were added. + if len(final_tokens) == 2: + return 0 + else: + final_tokens_len = len(final_tokens) + else: + initial_tokens_len = len(self.encode("This is a sequence")) + final_tokens_len = len(self.encode("This is a sequence", add_special_tokens=True)) + + return final_tokens_len - initial_tokens_len def add_special_tokens(self, special_tokens_dict): """