Always truncate argument in the encode method

This commit is contained in:
LysandreJik
2019-09-30 10:20:14 -04:00
parent 7af0777910
commit 7c789c337d
2 changed files with 48 additions and 12 deletions

View File

@@ -232,6 +232,23 @@ class CommonTestCases:
assert len(truncated_sequence) == total_length - 2
assert truncated_sequence == tokenizer.add_special_tokens_single_sequence(sequence[:-2])
def test_always_truncate(self):
tokenizer = self.get_tokenizer()
seq_0 = "This is a sentence to be encoded."
length_single_sequence = len(tokenizer.encode(seq_0))
length = len(tokenizer.encode(seq_0, seq_0, add_special_tokens=True))
not_truncated = tokenizer.encode(seq_0, seq_0, add_special_tokens=True, max_length=length_single_sequence)
truncated = tokenizer.encode(
seq_0, seq_0,
max_length=length_single_sequence,
add_special_tokens=True,
always_truncate=True
)
assert truncated == not_truncated[:length_single_sequence - length]
def test_maximum_encoding_length_pair_input(self):
tokenizer = self.get_tokenizer()