always_truncate by default

This commit is contained in:
LysandreJik
2019-09-30 17:27:40 -04:00
parent 5ed50a93fb
commit 651bfb7ad5
2 changed files with 5 additions and 39 deletions

View File

@@ -232,23 +232,6 @@ class CommonTestCases:
assert len(truncated_sequence) == total_length - 2
assert truncated_sequence == tokenizer.add_special_tokens_single_sequence(sequence[:-2])
def test_always_truncate(self):
    """Compare pair encoding with and without ``always_truncate=True``.

    The same sentence is encoded as a pair twice, both times with
    ``max_length`` set to the encoded length of the sentence on its own:
    once with default options and once with ``always_truncate=True``.
    The flagged result must equal the default result sliced at
    ``single_len - pair_len``.
    """
    tok = self.get_tokenizer()
    text = "This is a sentence to be encoded."

    # Reference lengths: the sentence alone, and the full pair with
    # special tokens added.
    single_len = len(tok.encode(text))
    pair_len = len(tok.encode(text, text, add_special_tokens=True))

    baseline = tok.encode(
        text, text,
        add_special_tokens=True,
        max_length=single_len,
    )
    forced = tok.encode(
        text, text,
        max_length=single_len,
        add_special_tokens=True,
        always_truncate=True,
    )

    # Presumably negative (pair longer than single), so the slice drops
    # that many trailing tokens from the baseline -- TODO confirm.
    slice_end = single_len - pair_len
    assert forced == baseline[:slice_end]
def test_maximum_encoding_length_pair_input(self):
tokenizer = self.get_tokenizer()
@@ -329,7 +312,6 @@ class CommonTestCases:
sequence_ids_orig = encoded_sequence_dict["sequence_ids"]
sequence_ids = tokenizer.get_sequence_ids(encoded_sequence_w_special, special_tokens_present=True)
assert len(sequence_ids) == len(encoded_sequence_w_special)
print(sequence_ids_orig, sequence_ids)
assert sequence_ids_orig == sequence_ids