Max encoding length + corresponding tests

This commit is contained in:
LysandreJik
2019-09-11 18:20:24 +02:00
parent c4d4f3ec8c
commit af23b626c8
2 changed files with 48 additions and 3 deletions

View File

@@ -211,3 +211,32 @@ class CommonTestCases:
# Method is implemented (e.g. not GPT-2)
if len(attached_sequences) != 2:
assert tokenizer.num_added_tokens(pair=True) == len(attached_sequences) - sum([len(seq) for seq in sequences])
def test_maximum_encoding_length_single_input(self):
tokenizer = self.get_tokenizer()
seq_0 = "This is a sentence to be encoded."
sequence = tokenizer.encode(seq_0)
num_added_tokens = tokenizer.num_added_tokens()
total_length = len(sequence) + num_added_tokens
truncated_sequence = tokenizer.encode(seq_0, max_length=total_length - 2, add_special_tokens=True)
assert len(truncated_sequence) == total_length - 2
assert truncated_sequence == tokenizer.add_special_tokens_single_sentence(sequence[:-2])
def test_maximum_encoding_length_pair_input(self):
tokenizer = self.get_tokenizer()
seq_0 = "This is a sentence to be encoded."
seq_1 = "This is another sentence to be encoded."
sequence = tokenizer.encode(seq_0, seq_1, add_special_tokens=True)
truncated_second_sequence = tokenizer.add_special_tokens_sentences_pair(
tokenizer.encode(seq_0),
tokenizer.encode(seq_1)[:-2]
)
truncated_sequence = tokenizer.encode(seq_0, seq_1, max_length=len(sequence) - 2, add_special_tokens=True)
assert len(truncated_sequence) == len(sequence) - 2
assert truncated_sequence == truncated_second_sequence