Batch encode plus with overflowing tokens fails when a sequence has no overflowing tokens (#6677)

* Patch and test

* Fix tests
This commit is contained in:
Lysandre Debut
2020-09-09 12:55:17 +02:00
committed by GitHub
parent 9fd11bf1a8
commit 15478c1287
2 changed files with 17 additions and 3 deletions

View File

@@ -2440,6 +2440,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
total_len = len_ids + len_pair_ids + (self.num_special_tokens_to_add(pair=pair) if add_special_tokens else 0) total_len = len_ids + len_pair_ids + (self.num_special_tokens_to_add(pair=pair) if add_special_tokens else 0)
# Truncation: Handle max sequence length # Truncation: Handle max sequence length
overflowing_tokens = []
if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE and max_length and total_len > max_length: if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE and max_length and total_len > max_length:
ids, pair_ids, overflowing_tokens = self.truncate_sequences( ids, pair_ids, overflowing_tokens = self.truncate_sequences(
ids, ids,
@@ -2448,6 +2449,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
truncation_strategy=truncation_strategy, truncation_strategy=truncation_strategy,
stride=stride, stride=stride,
) )
if return_overflowing_tokens: if return_overflowing_tokens:
encoded_inputs["overflowing_tokens"] = overflowing_tokens encoded_inputs["overflowing_tokens"] = overflowing_tokens
encoded_inputs["num_truncated_tokens"] = total_len - max_length encoded_inputs["num_truncated_tokens"] = total_len - max_length

View File

@@ -1352,6 +1352,18 @@ class TokenizerTesterMixin:
self.assertEqual(input_dict, prepared_input_dict) self.assertEqual(input_dict, prepared_input_dict)
def test_batch_encode_plus_overflowing_tokens(self):
    """Smoke-test ``batch_encode_plus`` with ``return_overflowing_tokens=True``.

    The batch pairs a long sentence with a very short one and encodes it
    with ``max_length=3``. The test makes no assertion on the output; it
    passes as long as the call completes without raising.
    """
    # Options forwarded unchanged to every tokenizer under test.
    encode_options = {
        "return_overflowing_tokens": True,
        "truncation": True,
        "padding": True,
        "max_length": 3,
    }
    for tok in self.get_tokenizers(do_lower_case=False):
        batch = ["Testing the prepare_for_model method.", "Test"]
        # padding=True is requested, so register a pad token when the
        # tokenizer does not define one.
        if tok.pad_token is None:
            tok.add_special_tokens({"pad_token": "[PAD]"})
        tok.batch_encode_plus(batch, **encode_options)
@require_torch @require_torch
@require_tf @require_tf
def test_batch_encode_plus_tensors(self): def test_batch_encode_plus_tensors(self):