Warning about too long input for fast tokenizers too (#8799)

* Warning about too long input for fast tokenizers too

If truncation is not set in tokenizers, but the tokenization is too long
for the model (`model_max_length`), we used to trigger a warning that

The input would probably fail (which it most likely will).

This PR re-enables the warning for fast tokenizers too and uses common
code for the trigger to make sure it's consistent across.

* Checking for pair of inputs too.

* Making the function private and adding it's doc.

* Remove formatting ?? in odd place.

* Missed uppercase.
This commit is contained in:
Nicolas Patry
2020-12-02 16:18:28 +01:00
committed by GitHub
parent f6b44e6190
commit a8c3f9aa76
3 changed files with 67 additions and 16 deletions

View File

@@ -2866,14 +2866,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
encoded_inputs["special_tokens_mask"] = [0] * len(sequence)
# Check lengths
if max_length is None and len(encoded_inputs["input_ids"]) > self.model_max_length and verbose:
if not self.deprecation_warnings.get("sequence-length-is-longer-than-the-specified-maximum", False):
logger.warning(
"Token indices sequence length is longer than the specified maximum sequence length "
"for this model ({} > {}). Running this sequence through the model will result in "
"indexing errors".format(len(encoded_inputs["input_ids"]), self.model_max_length)
)
self.deprecation_warnings["sequence-length-is-longer-than-the-specified-maximum"] = True
self._eventual_warn_about_too_long_sequence(encoded_inputs["input_ids"], max_length, verbose)
# Padding
if padding_strategy != PaddingStrategy.DO_NOT_PAD or return_attention_mask:
@@ -3204,3 +3197,23 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
.replace(" 're", "'re")
)
return out_string
def _eventual_warn_about_too_long_sequence(self, ids: List[int], max_length: Optional[int], verbose: bool):
"""
Depending on the input and internal state we might trigger a warning about a sequence that is too long for it's
corresponding model
Args:
ids (:obj:`List[str]`): The ids produced by the tokenization
max_length (:obj:`int`, `optional`): The max_length desired (does not trigger a warning if it is set)
verbose (:obj:`bool`): Whether or not to print more information and warnings.
"""
if max_length is None and len(ids) > self.model_max_length and verbose:
if not self.deprecation_warnings.get("sequence-length-is-longer-than-the-specified-maximum", False):
logger.warning(
"Token indices sequence length is longer than the specified maximum sequence length "
"for this model ({} > {}). Running this sequence through the model will result in "
"indexing errors".format(len(ids), self.model_max_length)
)
self.deprecation_warnings["sequence-length-is-longer-than-the-specified-maximum"] = True