Improve truncation_side (#14947)
* Enabling `truncation_side` for Slow and Fast tokenizer. Co-Authored-by: Niels Rogge <48327001+NielsRogge@users.noreply.github.com> * Disable failing tests. * Layout xlm. * assert -> assertEqual. Co-authored-by: Niels Rogge <48327001+NielsRogge@users.noreply.github.com>
This commit is contained in:
@@ -1398,6 +1398,64 @@ class TokenizerTesterMixin:
|
||||
assert sequence_length == padded_sequence_left_length
|
||||
assert encoded_sequence == padded_sequence_left
|
||||
|
||||
def test_right_and_left_truncation(self):
|
||||
tokenizers = self.get_tokenizers(do_lower_case=False)
|
||||
for tokenizer in tokenizers:
|
||||
with self.subTest(f"{tokenizer.__class__.__name__}"):
|
||||
sequence = "This is a test sequence"
|
||||
|
||||
# RIGHT PADDING - Check that it correctly pads when a maximum length is specified along with the padding flag set to True
|
||||
truncation_size = 3
|
||||
tokenizer.truncation_side = "right"
|
||||
encoded_sequence = tokenizer.encode(sequence, add_special_tokens=False)
|
||||
sequence_length = len(encoded_sequence)
|
||||
# Remove EOS/BOS tokens
|
||||
truncated_sequence = tokenizer.encode(
|
||||
sequence, max_length=sequence_length - truncation_size, truncation=True, add_special_tokens=False
|
||||
)
|
||||
truncated_sequence_length = len(truncated_sequence)
|
||||
self.assertEqual(sequence_length, truncated_sequence_length + truncation_size)
|
||||
self.assertEqual(encoded_sequence[:-truncation_size], truncated_sequence)
|
||||
|
||||
# LEFT PADDING - Check that it correctly pads when a maximum length is specified along with the truncation flag set to True
|
||||
tokenizer.truncation_side = "left"
|
||||
sequence_length = len(encoded_sequence)
|
||||
truncated_sequence = tokenizer.encode(
|
||||
sequence, max_length=sequence_length - truncation_size, truncation=True, add_special_tokens=False
|
||||
)
|
||||
truncated_sequence_length = len(truncated_sequence)
|
||||
self.assertEqual(sequence_length, truncated_sequence_length + truncation_size)
|
||||
self.assertEqual(encoded_sequence[truncation_size:], truncated_sequence)
|
||||
|
||||
# RIGHT & LEFT PADDING - Check that nothing is done for 'longest' and 'no_truncation'
|
||||
sequence_length = len(encoded_sequence)
|
||||
|
||||
tokenizer.truncation_side = "right"
|
||||
truncated_sequence_right = tokenizer.encode(sequence, truncation=True, add_special_tokens=False)
|
||||
truncated_sequence_right_length = len(truncated_sequence_right)
|
||||
self.assertEqual(sequence_length, truncated_sequence_right_length)
|
||||
self.assertEqual(encoded_sequence, truncated_sequence_right)
|
||||
|
||||
tokenizer.truncation_side = "left"
|
||||
truncated_sequence_left = tokenizer.encode(
|
||||
sequence, truncation="longest_first", add_special_tokens=False
|
||||
)
|
||||
truncated_sequence_left_length = len(truncated_sequence_left)
|
||||
self.assertEqual(sequence_length, truncated_sequence_left_length)
|
||||
self.assertEqual(encoded_sequence, truncated_sequence_left)
|
||||
|
||||
tokenizer.truncation_side = "right"
|
||||
truncated_sequence_right = tokenizer.encode(sequence, add_special_tokens=False)
|
||||
truncated_sequence_right_length = len(truncated_sequence_right)
|
||||
self.assertEqual(sequence_length, truncated_sequence_right_length)
|
||||
self.assertEqual(encoded_sequence, truncated_sequence_right)
|
||||
|
||||
tokenizer.truncation_side = "left"
|
||||
truncated_sequence_left = tokenizer.encode(sequence, truncation=False, add_special_tokens=False)
|
||||
truncated_sequence_left_length = len(truncated_sequence_left)
|
||||
self.assertEqual(sequence_length, truncated_sequence_left_length)
|
||||
self.assertEqual(encoded_sequence, truncated_sequence_left)
|
||||
|
||||
def test_padding_to_max_length(self):
|
||||
"""We keep this test for backward compatibility but it should be remove when `pad_to_max_length` is deprecated."""
|
||||
tokenizers = self.get_tokenizers(do_lower_case=False)
|
||||
|
||||
Reference in New Issue
Block a user