Padding strategy (left and right) rather than boolean flag

This commit is contained in:
LysandreJik
2019-11-21 11:30:40 -05:00
parent 9f374c8252
commit a7dafe2f41
2 changed files with 77 additions and 24 deletions

View File

@@ -343,21 +343,33 @@ class CommonTestCases:
padding_size = 10
padding_idx = tokenizer.pad_token_id
# Check that it correctly pads when a maximum length is specified along with the padding flag set to True
# RIGHT PADDING - Check that it correctly pads when a maximum length is specified along with the padding flag set to True
encoded_sequence = tokenizer.encode(sequence)
sequence_length = len(encoded_sequence)
padded_sequence = tokenizer.encode(sequence, max_length=sequence_length + padding_size, pad_to_max_length=True)
padded_sequence = tokenizer.encode(sequence, max_length=sequence_length + padding_size, padding_strategy='right')
padded_sequence_length = len(padded_sequence)
assert sequence_length + padding_size == padded_sequence_length
assert encoded_sequence + [padding_idx] * padding_size == padded_sequence
# Check that nothing is done when a maximum length is not specified
# LEFT PADDING - Check that it correctly pads when a maximum length is specified along with the padding flag set to True
encoded_sequence = tokenizer.encode(sequence)
sequence_length = len(encoded_sequence)
padded_sequence = tokenizer.encode(sequence, pad_to_max_length=True)
padded_sequence = tokenizer.encode(sequence, max_length=sequence_length + padding_size, padding_strategy='left')
padded_sequence_length = len(padded_sequence)
assert sequence_length == padded_sequence_length
assert encoded_sequence == padded_sequence
assert sequence_length + padding_size == padded_sequence_length
assert [padding_idx] * padding_size + encoded_sequence == padded_sequence
# RIGHT & LEFT PADDING - Check that nothing is done when a maximum length is not specified
encoded_sequence = tokenizer.encode(sequence)
sequence_length = len(encoded_sequence)
padded_sequence_right = tokenizer.encode(sequence, padding_strategy='right')
padded_sequence_right_length = len(padded_sequence_right)
padded_sequence_left = tokenizer.encode(sequence, padding_strategy='left')
padded_sequence_left_length = len(padded_sequence_left)
assert sequence_length == padded_sequence_right_length
assert encoded_sequence == padded_sequence_right
assert sequence_length == padded_sequence_left_length
assert encoded_sequence == padded_sequence_left
def test_encode_plus_with_padding(self):
tokenizer = self.get_tokenizer()
@@ -374,7 +386,8 @@ class CommonTestCases:
special_tokens_mask = encoded_sequence['special_tokens_mask']
sequence_length = len(input_ids)
padded_sequence = tokenizer.encode_plus(sequence, max_length=sequence_length + padding_size, pad_to_max_length=True, return_special_tokens_mask=True)
# Test right padding
padded_sequence = tokenizer.encode_plus(sequence, max_length=sequence_length + padding_size, padding_strategy='right', return_special_tokens_mask=True)
padded_input_ids = padded_sequence['input_ids']
padded_token_type_ids = padded_sequence['token_type_ids']
padded_attention_mask = padded_sequence['attention_mask']
@@ -385,4 +398,18 @@ class CommonTestCases:
assert input_ids + [padding_idx] * padding_size == padded_input_ids
assert token_type_ids + [token_type_padding_idx] * padding_size == padded_token_type_ids
assert attention_mask + [0] * padding_size == padded_attention_mask
assert special_tokens_mask + [1] * padding_size == padded_special_tokens_mask
assert special_tokens_mask + [1] * padding_size == padded_special_tokens_mask
# Test left padding
padded_sequence = tokenizer.encode_plus(sequence, max_length=sequence_length + padding_size, padding_strategy='left', return_special_tokens_mask=True)
padded_input_ids = padded_sequence['input_ids']
padded_token_type_ids = padded_sequence['token_type_ids']
padded_attention_mask = padded_sequence['attention_mask']
padded_special_tokens_mask = padded_sequence['special_tokens_mask']
padded_sequence_length = len(padded_input_ids)
assert sequence_length + padding_size == padded_sequence_length
assert [padding_idx] * padding_size + input_ids == padded_input_ids
assert [token_type_padding_idx] * padding_size + token_type_ids == padded_token_type_ids
assert [0] * padding_size + attention_mask == padded_attention_mask
assert [1] * padding_size + special_tokens_mask == padded_special_tokens_mask