fix #1532 and encode_plus
This commit is contained in:
@@ -223,7 +223,11 @@ class CommonTestCases:
|
||||
sequence = tokenizer.encode(seq_0, add_special_tokens=False)
|
||||
num_added_tokens = tokenizer.num_added_tokens()
|
||||
total_length = len(sequence) + num_added_tokens
|
||||
information = tokenizer.encode_plus(seq_0, max_length=total_length - 2, add_special_tokens=True, stride=stride)
|
||||
information = tokenizer.encode_plus(seq_0,
|
||||
max_length=total_length - 2,
|
||||
add_special_tokens=True,
|
||||
stride=stride,
|
||||
return_overflowing_tokens=True)
|
||||
|
||||
truncated_sequence = information["input_ids"]
|
||||
overflowing_tokens = information["overflowing_tokens"]
|
||||
@@ -250,10 +254,12 @@ class CommonTestCases:
|
||||
)
|
||||
|
||||
information = tokenizer.encode_plus(seq_0, seq_1, max_length=len(sequence) - 2, add_special_tokens=True,
|
||||
stride=stride, truncation_strategy='only_second')
|
||||
stride=stride, truncation_strategy='only_second',
|
||||
return_overflowing_tokens=True)
|
||||
information_first_truncated = tokenizer.encode_plus(seq_0, seq_1, max_length=len(sequence) - 2,
|
||||
add_special_tokens=True, stride=stride,
|
||||
truncation_strategy='only_first')
|
||||
truncation_strategy='only_first',
|
||||
return_overflowing_tokens=True)
|
||||
|
||||
truncated_sequence = information["input_ids"]
|
||||
overflowing_tokens = information["overflowing_tokens"]
|
||||
@@ -285,7 +291,7 @@ class CommonTestCases:
|
||||
|
||||
# Testing single inputs
|
||||
encoded_sequence = tokenizer.encode(sequence_0, add_special_tokens=False)
|
||||
encoded_sequence_dict = tokenizer.encode_plus(sequence_0, add_special_tokens=True)
|
||||
encoded_sequence_dict = tokenizer.encode_plus(sequence_0, add_special_tokens=True, return_special_tokens_mask=True)
|
||||
encoded_sequence_w_special = encoded_sequence_dict["input_ids"]
|
||||
special_tokens_mask = encoded_sequence_dict["special_tokens_mask"]
|
||||
self.assertEqual(len(special_tokens_mask), len(encoded_sequence_w_special))
|
||||
@@ -297,7 +303,8 @@ class CommonTestCases:
|
||||
# Testing inputs pairs
|
||||
encoded_sequence = tokenizer.encode(sequence_0, add_special_tokens=False) + tokenizer.encode(sequence_1,
|
||||
add_special_tokens=False)
|
||||
encoded_sequence_dict = tokenizer.encode_plus(sequence_0, sequence_1, add_special_tokens=True)
|
||||
encoded_sequence_dict = tokenizer.encode_plus(sequence_0, sequence_1, add_special_tokens=True,
|
||||
return_special_tokens_mask=True)
|
||||
encoded_sequence_w_special = encoded_sequence_dict["input_ids"]
|
||||
special_tokens_mask = encoded_sequence_dict["special_tokens_mask"]
|
||||
self.assertEqual(len(special_tokens_mask), len(encoded_sequence_w_special))
|
||||
@@ -309,7 +316,9 @@ class CommonTestCases:
|
||||
# Testing with already existing special tokens
|
||||
if tokenizer.cls_token_id == tokenizer.unk_token_id and tokenizer.cls_token_id == tokenizer.unk_token_id:
|
||||
tokenizer.add_special_tokens({'cls_token': '</s>', 'sep_token': '<s>'})
|
||||
encoded_sequence_dict = tokenizer.encode_plus(sequence_0, add_special_tokens=True)
|
||||
encoded_sequence_dict = tokenizer.encode_plus(sequence_0,
|
||||
add_special_tokens=True,
|
||||
return_special_tokens_mask=True)
|
||||
encoded_sequence_w_special = encoded_sequence_dict["input_ids"]
|
||||
special_tokens_mask_orig = encoded_sequence_dict["special_tokens_mask"]
|
||||
special_tokens_mask = tokenizer.get_special_tokens_mask(encoded_sequence_w_special, already_has_special_tokens=True)
|
||||
|
||||
Reference in New Issue
Block a user