update encode_plus - add truncation strategies
This commit is contained in:
@@ -131,8 +131,8 @@ class BertTokenizationTest(CommonTestCases.CommonTokenizerTester):
|
||||
text = tokenizer.encode("sequence builders")
|
||||
text_2 = tokenizer.encode("multi-sequence build")
|
||||
|
||||
encoded_sentence = tokenizer.add_special_tokens_single_sequence(text)
|
||||
encoded_pair = tokenizer.add_special_tokens_sequence_pair(text, text_2)
|
||||
encoded_sentence = tokenizer.build_inputs_with_special_tokens(text)
|
||||
encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2)
|
||||
|
||||
assert encoded_sentence == [101] + text + [102]
|
||||
assert encoded_pair == [101] + text + [102] + text_2 + [102]
|
||||
|
||||
@@ -36,8 +36,8 @@ class DistilBertTokenizationTest(BertTokenizationTest):
|
||||
text = tokenizer.encode("sequence builders")
|
||||
text_2 = tokenizer.encode("multi-sequence build")
|
||||
|
||||
encoded_sentence = tokenizer.add_special_tokens_single_sequence(text)
|
||||
encoded_pair = tokenizer.add_special_tokens_sequence_pair(text, text_2)
|
||||
encoded_sentence = tokenizer.build_inputs_with_special_tokens(text)
|
||||
encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2)
|
||||
|
||||
assert encoded_sentence == [tokenizer.cls_token_id] + text + [tokenizer.sep_token_id]
|
||||
assert encoded_pair == [tokenizer.cls_token_id] + text + [tokenizer.sep_token_id] + \
|
||||
|
||||
@@ -87,8 +87,8 @@ class RobertaTokenizationTest(CommonTestCases.CommonTokenizerTester):
|
||||
encoded_text_from_decode = tokenizer.encode("sequence builders", add_special_tokens=True)
|
||||
encoded_pair_from_decode = tokenizer.encode("sequence builders", "multi-sequence build", add_special_tokens=True)
|
||||
|
||||
encoded_sentence = tokenizer.add_special_tokens_single_sequence(text)
|
||||
encoded_pair = tokenizer.add_special_tokens_sequence_pair(text, text_2)
|
||||
encoded_sentence = tokenizer.build_inputs_with_special_tokens(text)
|
||||
encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2)
|
||||
|
||||
assert encoded_sentence == encoded_text_from_decode
|
||||
assert encoded_pair == encoded_pair_from_decode
|
||||
|
||||
@@ -193,12 +193,12 @@ class CommonTestCases:
|
||||
|
||||
tokenizer = self.get_tokenizer()
|
||||
|
||||
if tokenizer.add_special_tokens_sequence_pair.__qualname__.split('.')[0] != "PreTrainedTokenizer":
|
||||
if tokenizer.build_inputs_with_special_tokens.__qualname__.split('.')[0] != "PreTrainedTokenizer":
|
||||
seq_0 = "Test this method."
|
||||
seq_1 = "With these inputs."
|
||||
information = tokenizer.encode_plus(seq_0, seq_1, add_special_tokens=True)
|
||||
sequences, mask = information["input_ids"], information["token_type_ids"]
|
||||
assert len(sequences) == len(mask)
|
||||
self.assertEqual(len(sequences), len(mask))
|
||||
|
||||
def test_number_of_added_tokens(self):
|
||||
tokenizer = self.get_tokenizer()
|
||||
@@ -211,7 +211,7 @@ class CommonTestCases:
|
||||
|
||||
# Method is implemented (e.g. not GPT-2)
|
||||
if len(attached_sequences) != 2:
|
||||
assert tokenizer.num_added_tokens(pair=True) == len(attached_sequences) - len(sequences)
|
||||
self.assertEqual(tokenizer.num_added_tokens(pair=True), len(attached_sequences) - len(sequences))
|
||||
|
||||
def test_maximum_encoding_length_single_input(self):
|
||||
tokenizer = self.get_tokenizer()
|
||||
@@ -227,10 +227,10 @@ class CommonTestCases:
|
||||
truncated_sequence = information["input_ids"]
|
||||
overflowing_tokens = information["overflowing_tokens"]
|
||||
|
||||
assert len(overflowing_tokens) == 2 + stride
|
||||
assert overflowing_tokens == sequence[-(2 + stride):]
|
||||
assert len(truncated_sequence) == total_length - 2
|
||||
assert truncated_sequence == tokenizer.add_special_tokens_single_sequence(sequence[:-2])
|
||||
self.assertEqual(len(overflowing_tokens), 2 + stride)
|
||||
self.assertEqual(overflowing_tokens, sequence[-(2 + stride):])
|
||||
self.assertEqual(len(truncated_sequence), total_length - 2)
|
||||
self.assertEqual(truncated_sequence, tokenizer.build_inputs_with_special_tokens(sequence[:-2]))
|
||||
|
||||
def test_maximum_encoding_length_pair_input(self):
|
||||
tokenizer = self.get_tokenizer()
|
||||
@@ -243,7 +243,7 @@ class CommonTestCases:
|
||||
sequence_1_no_special_tokens = tokenizer.encode(seq_1)
|
||||
|
||||
sequence = tokenizer.encode(seq_0, seq_1, add_special_tokens=True)
|
||||
truncated_second_sequence = tokenizer.add_special_tokens_sequence_pair(
|
||||
truncated_second_sequence = tokenizer.build_inputs_with_special_tokens(
|
||||
tokenizer.encode(seq_0),
|
||||
tokenizer.encode(seq_1)[:-2]
|
||||
)
|
||||
@@ -258,11 +258,11 @@ class CommonTestCases:
|
||||
overflowing_tokens = information["overflowing_tokens"]
|
||||
overflowing_tokens_first_truncated = information_first_truncated["overflowing_tokens"]
|
||||
|
||||
assert len(overflowing_tokens) == 2 + stride
|
||||
assert overflowing_tokens == sequence_1_no_special_tokens[-(2 + stride):]
|
||||
assert overflowing_tokens_first_truncated == sequence_0_no_special_tokens[-(2 + stride):]
|
||||
assert len(truncated_sequence) == len(sequence) - 2
|
||||
assert truncated_sequence == truncated_second_sequence
|
||||
self.assertEqual(len(overflowing_tokens), 2 + stride)
|
||||
self.assertEqual(overflowing_tokens, sequence_1_no_special_tokens[-(2 + stride):])
|
||||
self.assertEqual(overflowing_tokens_first_truncated, sequence_0_no_special_tokens[-(2 + stride):])
|
||||
self.assertEqual(len(truncated_sequence), len(sequence) - 2)
|
||||
self.assertEqual(truncated_sequence, truncated_second_sequence)
|
||||
|
||||
def test_encode_input_type(self):
|
||||
tokenizer = self.get_tokenizer()
|
||||
@@ -273,8 +273,8 @@ class CommonTestCases:
|
||||
input_ids = tokenizer.convert_tokens_to_ids(tokens)
|
||||
formatted_input = tokenizer.encode(sequence, add_special_tokens=True)
|
||||
|
||||
assert tokenizer.encode(tokens, add_special_tokens=True) == formatted_input
|
||||
assert tokenizer.encode(input_ids, add_special_tokens=True) == formatted_input
|
||||
self.assertEqual(tokenizer.encode(tokens, add_special_tokens=True), formatted_input)
|
||||
self.assertEqual(tokenizer.encode(input_ids, add_special_tokens=True), formatted_input)
|
||||
|
||||
def test_special_tokens_mask(self):
|
||||
tokenizer = self.get_tokenizer()
|
||||
@@ -287,22 +287,22 @@ class CommonTestCases:
|
||||
encoded_sequence_dict = tokenizer.encode_plus(sequence_0, add_special_tokens=True)
|
||||
encoded_sequence_w_special = encoded_sequence_dict["input_ids"]
|
||||
special_tokens_mask = encoded_sequence_dict["special_tokens_mask"]
|
||||
assert len(special_tokens_mask) == len(encoded_sequence_w_special)
|
||||
self.assertEqual(len(special_tokens_mask), len(encoded_sequence_w_special))
|
||||
|
||||
filtered_sequence = [(x if special_tokens_mask[i] else None) for i, x in enumerate(encoded_sequence_w_special)]
|
||||
filtered_sequence = [(x if not special_tokens_mask[i] else None) for i, x in enumerate(encoded_sequence_w_special)]
|
||||
filtered_sequence = [x for x in filtered_sequence if x is not None]
|
||||
assert encoded_sequence == filtered_sequence
|
||||
self.assertEqual(encoded_sequence, filtered_sequence)
|
||||
|
||||
# Testing inputs pairs
|
||||
encoded_sequence = tokenizer.encode(sequence_0) + tokenizer.encode(sequence_1)
|
||||
encoded_sequence_dict = tokenizer.encode_plus(sequence_0, sequence_1, add_special_tokens=True)
|
||||
encoded_sequence_w_special = encoded_sequence_dict["input_ids"]
|
||||
special_tokens_mask = encoded_sequence_dict["special_tokens_mask"]
|
||||
assert len(special_tokens_mask) == len(encoded_sequence_w_special)
|
||||
self.assertEqual(len(special_tokens_mask), len(encoded_sequence_w_special))
|
||||
|
||||
filtered_sequence = [(x if special_tokens_mask[i] else None) for i, x in enumerate(encoded_sequence_w_special)]
|
||||
filtered_sequence = [(x if not special_tokens_mask[i] else None) for i, x in enumerate(encoded_sequence_w_special)]
|
||||
filtered_sequence = [x for x in filtered_sequence if x is not None]
|
||||
assert encoded_sequence == filtered_sequence
|
||||
self.assertEqual(encoded_sequence, filtered_sequence)
|
||||
|
||||
# Testing with already existing special tokens
|
||||
if tokenizer.cls_token_id == tokenizer.unk_token_id and tokenizer.cls_token_id == tokenizer.unk_token_id:
|
||||
@@ -310,9 +310,6 @@ class CommonTestCases:
|
||||
encoded_sequence_dict = tokenizer.encode_plus(sequence_0, add_special_tokens=True)
|
||||
encoded_sequence_w_special = encoded_sequence_dict["input_ids"]
|
||||
special_tokens_mask_orig = encoded_sequence_dict["special_tokens_mask"]
|
||||
special_tokens_mask = tokenizer.get_special_tokens_mask(encoded_sequence_w_special, special_tokens_present=True)
|
||||
assert len(special_tokens_mask) == len(encoded_sequence_w_special)
|
||||
assert special_tokens_mask_orig == special_tokens_mask
|
||||
|
||||
|
||||
|
||||
special_tokens_mask = tokenizer.get_special_tokens_mask(encoded_sequence_w_special, already_has_special_tokens=True)
|
||||
self.assertEqual(len(special_tokens_mask), len(encoded_sequence_w_special))
|
||||
self.assertEqual(special_tokens_mask_orig, special_tokens_mask)
|
||||
|
||||
@@ -72,8 +72,8 @@ class XLMTokenizationTest(CommonTestCases.CommonTokenizerTester):
|
||||
text = tokenizer.encode("sequence builders")
|
||||
text_2 = tokenizer.encode("multi-sequence build")
|
||||
|
||||
encoded_sentence = tokenizer.add_special_tokens_single_sequence(text)
|
||||
encoded_pair = tokenizer.add_special_tokens_sequence_pair(text, text_2)
|
||||
encoded_sentence = tokenizer.build_inputs_with_special_tokens(text)
|
||||
encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2)
|
||||
|
||||
assert encoded_sentence == [1] + text + [1]
|
||||
assert encoded_pair == [1] + text + [1] + text_2 + [1]
|
||||
|
||||
@@ -95,8 +95,8 @@ class XLNetTokenizationTest(CommonTestCases.CommonTokenizerTester):
|
||||
text = tokenizer.encode("sequence builders")
|
||||
text_2 = tokenizer.encode("multi-sequence build")
|
||||
|
||||
encoded_sentence = tokenizer.add_special_tokens_single_sequence(text)
|
||||
encoded_pair = tokenizer.add_special_tokens_sequence_pair(text, text_2)
|
||||
encoded_sentence = tokenizer.build_inputs_with_special_tokens(text)
|
||||
encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2)
|
||||
|
||||
assert encoded_sentence == text + [4, 3]
|
||||
assert encoded_pair == text + [4] + text_2 + [4, 3]
|
||||
|
||||
Reference in New Issue
Block a user