Number of added tokens calculator
This commit is contained in:
@@ -198,3 +198,16 @@ class CommonTestCases:
|
|||||||
seq_1 = "With these inputs."
|
seq_1 = "With these inputs."
|
||||||
sequences, mask = tokenizer.encode(seq_0, seq_1, add_special_tokens=True, output_mask=True)
|
sequences, mask = tokenizer.encode(seq_0, seq_1, add_special_tokens=True, output_mask=True)
|
||||||
assert len(sequences) == len(mask)
|
assert len(sequences) == len(mask)
|
||||||
|
|
||||||
|
def test_number_of_added_tokens(self):
|
||||||
|
tokenizer = self.get_tokenizer()
|
||||||
|
|
||||||
|
seq_0 = "Test this method."
|
||||||
|
seq_1 = "With these inputs."
|
||||||
|
|
||||||
|
sequences = tokenizer.encode(seq_0, seq_1)
|
||||||
|
attached_sequences = tokenizer.encode(seq_0, seq_1, add_special_tokens=True)
|
||||||
|
|
||||||
|
# Method is implemented (e.g. not GPT-2)
|
||||||
|
if len(attached_sequences) != 2:
|
||||||
|
assert tokenizer.num_added_tokens(pair=True) == len(attached_sequences) - sum([len(seq) for seq in sequences])
|
||||||
|
|||||||
@@ -518,6 +518,36 @@ class PreTrainedTokenizer(object):
|
|||||||
|
|
||||||
return len(to_add_tokens)
|
return len(to_add_tokens)
|
||||||
|
|
||||||
|
def num_added_tokens(self, pair=False):
|
||||||
|
"""
|
||||||
|
Returns the number of added tokens when encoding a sequence with special tokens.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
This encodes inputs and checks the number of added tokens, and is therefore not efficient. Do not put this
|
||||||
|
inside your training loop.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pair: Returns the number of added tokens in the case of a sequence pair if set to True, returns the
|
||||||
|
number of added tokens in the case of a single sequence if set to False.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of tokens added to sequences
|
||||||
|
"""
|
||||||
|
|
||||||
|
if pair:
|
||||||
|
initial_tokens_len = sum([len(encoded) for encoded in self.encode("This is a sequence", "This is another")])
|
||||||
|
final_tokens = self.encode("This is a sequence", "This is another", add_special_tokens=True)
|
||||||
|
|
||||||
|
# In some models (e.g. GPT-2), there is no sequence pair encoding.
|
||||||
|
if len(final_tokens) == 2:
|
||||||
|
return 0
|
||||||
|
else:
|
||||||
|
final_tokens_len = len(final_tokens)
|
||||||
|
else:
|
||||||
|
initial_tokens_len = len(self.encode("This is a sequence"))
|
||||||
|
final_tokens_len = len(self.encode("This is a sequence", add_special_tokens=True))
|
||||||
|
|
||||||
|
return final_tokens_len - initial_tokens_len
|
||||||
|
|
||||||
def add_special_tokens(self, special_tokens_dict):
|
def add_special_tokens(self, special_tokens_dict):
|
||||||
"""
|
"""
|
||||||
|
|||||||
Reference in New Issue
Block a user