Padding strategy (left and right) rather than boolean flag
This commit is contained in:
@@ -343,21 +343,33 @@ class CommonTestCases:
|
||||
padding_size = 10
|
||||
padding_idx = tokenizer.pad_token_id
|
||||
|
||||
# Check that it correctly pads when a maximum length is specified along with the padding flag set to True
|
||||
# RIGHT PADDING - Check that it correctly pads when a maximum length is specified along with the padding flag set to True
|
||||
encoded_sequence = tokenizer.encode(sequence)
|
||||
sequence_length = len(encoded_sequence)
|
||||
padded_sequence = tokenizer.encode(sequence, max_length=sequence_length + padding_size, pad_to_max_length=True)
|
||||
padded_sequence = tokenizer.encode(sequence, max_length=sequence_length + padding_size, padding_strategy='right')
|
||||
padded_sequence_length = len(padded_sequence)
|
||||
assert sequence_length + padding_size == padded_sequence_length
|
||||
assert encoded_sequence + [padding_idx] * padding_size == padded_sequence
|
||||
|
||||
# Check that nothing is done when a maximum length is not specified
|
||||
# LEFT PADDING - Check that it correctly pads when a maximum length is specified along with the padding flag set to True
|
||||
encoded_sequence = tokenizer.encode(sequence)
|
||||
sequence_length = len(encoded_sequence)
|
||||
padded_sequence = tokenizer.encode(sequence, pad_to_max_length=True)
|
||||
padded_sequence = tokenizer.encode(sequence, max_length=sequence_length + padding_size, padding_strategy='left')
|
||||
padded_sequence_length = len(padded_sequence)
|
||||
assert sequence_length == padded_sequence_length
|
||||
assert encoded_sequence == padded_sequence
|
||||
assert sequence_length + padding_size == padded_sequence_length
|
||||
assert [padding_idx] * padding_size + encoded_sequence == padded_sequence
|
||||
|
||||
# RIGHT & LEFT PADDING - Check that nothing is done when a maximum length is not specified
|
||||
encoded_sequence = tokenizer.encode(sequence)
|
||||
sequence_length = len(encoded_sequence)
|
||||
padded_sequence_right = tokenizer.encode(sequence, padding_strategy='right')
|
||||
padded_sequence_right_length = len(padded_sequence_right)
|
||||
padded_sequence_left = tokenizer.encode(sequence, padding_strategy='left')
|
||||
padded_sequence_left_length = len(padded_sequence_left)
|
||||
assert sequence_length == padded_sequence_right_length
|
||||
assert encoded_sequence == padded_sequence_right
|
||||
assert sequence_length == padded_sequence_left_length
|
||||
assert encoded_sequence == padded_sequence_left
|
||||
|
||||
def test_encode_plus_with_padding(self):
|
||||
tokenizer = self.get_tokenizer()
|
||||
@@ -374,7 +386,8 @@ class CommonTestCases:
|
||||
special_tokens_mask = encoded_sequence['special_tokens_mask']
|
||||
sequence_length = len(input_ids)
|
||||
|
||||
padded_sequence = tokenizer.encode_plus(sequence, max_length=sequence_length + padding_size, pad_to_max_length=True, return_special_tokens_mask=True)
|
||||
# Test right padding
|
||||
padded_sequence = tokenizer.encode_plus(sequence, max_length=sequence_length + padding_size, padding_strategy='right', return_special_tokens_mask=True)
|
||||
padded_input_ids = padded_sequence['input_ids']
|
||||
padded_token_type_ids = padded_sequence['token_type_ids']
|
||||
padded_attention_mask = padded_sequence['attention_mask']
|
||||
@@ -385,4 +398,18 @@ class CommonTestCases:
|
||||
assert input_ids + [padding_idx] * padding_size == padded_input_ids
|
||||
assert token_type_ids + [token_type_padding_idx] * padding_size == padded_token_type_ids
|
||||
assert attention_mask + [0] * padding_size == padded_attention_mask
|
||||
assert special_tokens_mask + [1] * padding_size == padded_special_tokens_mask
|
||||
assert special_tokens_mask + [1] * padding_size == padded_special_tokens_mask
|
||||
|
||||
# Test left padding
|
||||
padded_sequence = tokenizer.encode_plus(sequence, max_length=sequence_length + padding_size, padding_strategy='left', return_special_tokens_mask=True)
|
||||
padded_input_ids = padded_sequence['input_ids']
|
||||
padded_token_type_ids = padded_sequence['token_type_ids']
|
||||
padded_attention_mask = padded_sequence['attention_mask']
|
||||
padded_special_tokens_mask = padded_sequence['special_tokens_mask']
|
||||
padded_sequence_length = len(padded_input_ids)
|
||||
|
||||
assert sequence_length + padding_size == padded_sequence_length
|
||||
assert [padding_idx] * padding_size + input_ids == padded_input_ids
|
||||
assert [token_type_padding_idx] * padding_size + token_type_ids == padded_token_type_ids
|
||||
assert [0] * padding_size + attention_mask == padded_attention_mask
|
||||
assert [1] * padding_size + special_tokens_mask == padded_special_tokens_mask
|
||||
Reference in New Issue
Block a user