encode + encode_plus tests modified
This commit is contained in:
@@ -196,7 +196,8 @@ class CommonTestCases:
|
|||||||
if tokenizer.add_special_tokens_sentences_pair.__qualname__.split('.')[0] != "PreTrainedTokenizer":
|
if tokenizer.add_special_tokens_sentences_pair.__qualname__.split('.')[0] != "PreTrainedTokenizer":
|
||||||
seq_0 = "Test this method."
|
seq_0 = "Test this method."
|
||||||
seq_1 = "With these inputs."
|
seq_1 = "With these inputs."
|
||||||
sequences, mask = tokenizer.encode(seq_0, seq_1, add_special_tokens=True, output_mask=True)
|
information = tokenizer.encode_plus(seq_0, seq_1, add_special_tokens=True, output_mask=True)
|
||||||
|
sequences, mask = information["sequence"], information["mask"]
|
||||||
assert len(sequences) == len(mask)
|
assert len(sequences) == len(mask)
|
||||||
|
|
||||||
def test_number_of_added_tokens(self):
|
def test_number_of_added_tokens(self):
|
||||||
@@ -210,7 +211,7 @@ class CommonTestCases:
|
|||||||
|
|
||||||
# Method is implemented (e.g. not GPT-2)
|
# Method is implemented (e.g. not GPT-2)
|
||||||
if len(attached_sequences) != 2:
|
if len(attached_sequences) != 2:
|
||||||
assert tokenizer.num_added_tokens(pair=True) == len(attached_sequences) - sum([len(seq) for seq in sequences])
|
assert tokenizer.num_added_tokens(pair=True) == len(attached_sequences) - len(sequences)
|
||||||
|
|
||||||
def test_maximum_encoding_length_single_input(self):
|
def test_maximum_encoding_length_single_input(self):
|
||||||
tokenizer = self.get_tokenizer()
|
tokenizer = self.get_tokenizer()
|
||||||
@@ -220,8 +221,12 @@ class CommonTestCases:
|
|||||||
sequence = tokenizer.encode(seq_0)
|
sequence = tokenizer.encode(seq_0)
|
||||||
num_added_tokens = tokenizer.num_added_tokens()
|
num_added_tokens = tokenizer.num_added_tokens()
|
||||||
total_length = len(sequence) + num_added_tokens
|
total_length = len(sequence) + num_added_tokens
|
||||||
truncated_sequence = tokenizer.encode(seq_0, max_length=total_length - 2, add_special_tokens=True)
|
information = tokenizer.encode_plus(seq_0, max_length=total_length - 2, add_special_tokens=True)
|
||||||
|
|
||||||
|
truncated_sequence = information["sequence"]
|
||||||
|
overflowing_tokens = information["overflowing_tokens"]
|
||||||
|
|
||||||
|
assert len(overflowing_tokens) == 2
|
||||||
assert len(truncated_sequence) == total_length - 2
|
assert len(truncated_sequence) == total_length - 2
|
||||||
assert truncated_sequence == tokenizer.add_special_tokens_single_sentence(sequence[:-2])
|
assert truncated_sequence == tokenizer.add_special_tokens_single_sentence(sequence[:-2])
|
||||||
|
|
||||||
@@ -236,7 +241,10 @@ class CommonTestCases:
|
|||||||
tokenizer.encode(seq_0),
|
tokenizer.encode(seq_0),
|
||||||
tokenizer.encode(seq_1)[:-2]
|
tokenizer.encode(seq_1)[:-2]
|
||||||
)
|
)
|
||||||
truncated_sequence = tokenizer.encode(seq_0, seq_1, max_length=len(sequence) - 2, add_special_tokens=True)
|
information = tokenizer.encode_plus(seq_0, seq_1, max_length=len(sequence) - 2, add_special_tokens=True)
|
||||||
|
|
||||||
|
truncated_sequence = information["sequence"]
|
||||||
|
overflowing_tokens = information["overflowing_tokens"]
|
||||||
|
|
||||||
assert len(truncated_sequence) == len(sequence) - 2
|
assert len(truncated_sequence) == len(sequence) - 2
|
||||||
assert truncated_sequence == truncated_second_sequence
|
assert truncated_sequence == truncated_second_sequence
|
||||||
|
|||||||
Reference in New Issue
Block a user