Batch encore plus and overflowing tokens fails when non existing overflowing tokens for a sequence (#6677)
* Patch and test * Fix tests
This commit is contained in:
@@ -2440,6 +2440,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
|
|||||||
total_len = len_ids + len_pair_ids + (self.num_special_tokens_to_add(pair=pair) if add_special_tokens else 0)
|
total_len = len_ids + len_pair_ids + (self.num_special_tokens_to_add(pair=pair) if add_special_tokens else 0)
|
||||||
|
|
||||||
# Truncation: Handle max sequence length
|
# Truncation: Handle max sequence length
|
||||||
|
overflowing_tokens = []
|
||||||
if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE and max_length and total_len > max_length:
|
if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE and max_length and total_len > max_length:
|
||||||
ids, pair_ids, overflowing_tokens = self.truncate_sequences(
|
ids, pair_ids, overflowing_tokens = self.truncate_sequences(
|
||||||
ids,
|
ids,
|
||||||
@@ -2448,9 +2449,10 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
|
|||||||
truncation_strategy=truncation_strategy,
|
truncation_strategy=truncation_strategy,
|
||||||
stride=stride,
|
stride=stride,
|
||||||
)
|
)
|
||||||
if return_overflowing_tokens:
|
|
||||||
encoded_inputs["overflowing_tokens"] = overflowing_tokens
|
if return_overflowing_tokens:
|
||||||
encoded_inputs["num_truncated_tokens"] = total_len - max_length
|
encoded_inputs["overflowing_tokens"] = overflowing_tokens
|
||||||
|
encoded_inputs["num_truncated_tokens"] = total_len - max_length
|
||||||
|
|
||||||
# Add special tokens
|
# Add special tokens
|
||||||
if add_special_tokens:
|
if add_special_tokens:
|
||||||
|
|||||||
@@ -1352,6 +1352,18 @@ class TokenizerTesterMixin:
|
|||||||
|
|
||||||
self.assertEqual(input_dict, prepared_input_dict)
|
self.assertEqual(input_dict, prepared_input_dict)
|
||||||
|
|
||||||
|
def test_batch_encode_plus_overflowing_tokens(self):
|
||||||
|
tokenizers = self.get_tokenizers(do_lower_case=False)
|
||||||
|
for tokenizer in tokenizers:
|
||||||
|
string_sequences = ["Testing the prepare_for_model method.", "Test"]
|
||||||
|
|
||||||
|
if tokenizer.pad_token is None:
|
||||||
|
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
|
||||||
|
|
||||||
|
tokenizer.batch_encode_plus(
|
||||||
|
string_sequences, return_overflowing_tokens=True, truncation=True, padding=True, max_length=3
|
||||||
|
)
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
@require_tf
|
@require_tf
|
||||||
def test_batch_encode_plus_tensors(self):
|
def test_batch_encode_plus_tensors(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user