encode and encode_plus handle attention masks and padding
This commit is contained in:
@@ -335,3 +335,54 @@ class CommonTestCases:
|
||||
special_tokens_mask = tokenizer.get_special_tokens_mask(encoded_sequence_w_special, already_has_special_tokens=True)
|
||||
self.assertEqual(len(special_tokens_mask), len(encoded_sequence_w_special))
|
||||
self.assertEqual(special_tokens_mask_orig, special_tokens_mask)
|
||||
|
||||
def test_padding_to_max_length(self):
|
||||
tokenizer = self.get_tokenizer()
|
||||
|
||||
sequence = "Sequence"
|
||||
padding_size = 10
|
||||
padding_idx = tokenizer.pad_token_id
|
||||
|
||||
# Check that it correctly pads when a maximum length is specified along with the padding flag set to True
|
||||
encoded_sequence = tokenizer.encode(sequence)
|
||||
sequence_length = len(encoded_sequence)
|
||||
padded_sequence = tokenizer.encode(sequence, max_length=sequence_length + padding_size, pad_to_max_length=True)
|
||||
padded_sequence_length = len(padded_sequence)
|
||||
assert sequence_length + padding_size == padded_sequence_length
|
||||
assert encoded_sequence + [padding_idx] * padding_size == padded_sequence
|
||||
|
||||
# Check that nothing is done when a maximum length is not specified
|
||||
encoded_sequence = tokenizer.encode(sequence)
|
||||
sequence_length = len(encoded_sequence)
|
||||
padded_sequence = tokenizer.encode(sequence, pad_to_max_length=True)
|
||||
padded_sequence_length = len(padded_sequence)
|
||||
assert sequence_length == padded_sequence_length
|
||||
assert encoded_sequence == padded_sequence
|
||||
|
||||
def test_encode_plus_with_padding(self):
|
||||
tokenizer = self.get_tokenizer()
|
||||
|
||||
sequence = "Sequence"
|
||||
padding_size = 10
|
||||
padding_idx = tokenizer.pad_token_id
|
||||
token_type_padding_idx = tokenizer.pad_token_type_id
|
||||
|
||||
encoded_sequence = tokenizer.encode_plus(sequence, return_special_tokens_mask=True)
|
||||
input_ids = encoded_sequence['input_ids']
|
||||
token_type_ids = encoded_sequence['token_type_ids']
|
||||
attention_mask = encoded_sequence['attention_mask']
|
||||
special_tokens_mask = encoded_sequence['special_tokens_mask']
|
||||
sequence_length = len(input_ids)
|
||||
|
||||
padded_sequence = tokenizer.encode_plus(sequence, max_length=sequence_length + padding_size, pad_to_max_length=True, return_special_tokens_mask=True)
|
||||
padded_input_ids = padded_sequence['input_ids']
|
||||
padded_token_type_ids = padded_sequence['token_type_ids']
|
||||
padded_attention_mask = padded_sequence['attention_mask']
|
||||
padded_special_tokens_mask = padded_sequence['special_tokens_mask']
|
||||
padded_sequence_length = len(padded_input_ids)
|
||||
|
||||
assert sequence_length + padding_size == padded_sequence_length
|
||||
assert input_ids + [padding_idx] * padding_size == padded_input_ids
|
||||
assert token_type_ids + [token_type_padding_idx] * padding_size == padded_token_type_ids
|
||||
assert attention_mask + [0] * padding_size == padded_attention_mask
|
||||
assert special_tokens_mask + [1] * padding_size == padded_special_tokens_mask
|
||||
Reference in New Issue
Block a user