From a8c3f9aa760ed7b516ee00f602e8efc0e5d80285 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Wed, 2 Dec 2020 16:18:28 +0100 Subject: [PATCH] Warning about too long input for fast tokenizers too (#8799) * Warning about too long input for fast tokenizers too If truncation is not set in tokenizers, but the tokenization is too long for the model (`model_max_length`), we used to trigger a warning that the input would probably fail (which it most likely will). This PR re-enables the warning for fast tokenizers too and uses common code for the trigger to make sure it's consistent across slow and fast tokenizers. * Checking for pairs of inputs too. * Making the function private and adding its doc. * Remove formatting ?? in odd place. * Missed uppercase. --- src/transformers/tokenization_utils_base.py | 29 ++++++++---- src/transformers/tokenization_utils_fast.py | 4 ++ tests/test_tokenization_common.py | 50 +++++++++++++++++---- 3 files changed, 67 insertions(+), 16 deletions(-) diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index 9eb1a5ae40..d8e73217ee 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -2866,14 +2866,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin): encoded_inputs["special_tokens_mask"] = [0] * len(sequence) # Check lengths - if max_length is None and len(encoded_inputs["input_ids"]) > self.model_max_length and verbose: - if not self.deprecation_warnings.get("sequence-length-is-longer-than-the-specified-maximum", False): - logger.warning( - "Token indices sequence length is longer than the specified maximum sequence length " - "for this model ({} > {}). 
Running this sequence through the model will result in " - "indexing errors".format(len(encoded_inputs["input_ids"]), self.model_max_length) - ) - self.deprecation_warnings["sequence-length-is-longer-than-the-specified-maximum"] = True + self._eventual_warn_about_too_long_sequence(encoded_inputs["input_ids"], max_length, verbose) # Padding if padding_strategy != PaddingStrategy.DO_NOT_PAD or return_attention_mask: @@ -3204,3 +3197,23 @@ class PreTrainedTokenizerBase(SpecialTokensMixin): .replace(" 're", "'re") ) return out_string + + def _eventual_warn_about_too_long_sequence(self, ids: List[int], max_length: Optional[int], verbose: bool): + """ + Depending on the input and internal state we might trigger a warning about a sequence that is too long for it's + corresponding model + + Args: + ids (:obj:`List[str]`): The ids produced by the tokenization + max_length (:obj:`int`, `optional`): The max_length desired (does not trigger a warning if it is set) + verbose (:obj:`bool`): Whether or not to print more information and warnings. + + """ + if max_length is None and len(ids) > self.model_max_length and verbose: + if not self.deprecation_warnings.get("sequence-length-is-longer-than-the-specified-maximum", False): + logger.warning( + "Token indices sequence length is longer than the specified maximum sequence length " + "for this model ({} > {}). 
Running this sequence through the model will result in " + "indexing errors".format(len(ids), self.model_max_length) + ) + self.deprecation_warnings["sequence-length-is-longer-than-the-specified-maximum"] = True diff --git a/src/transformers/tokenization_utils_fast.py b/src/transformers/tokenization_utils_fast.py index c447a0b7f7..92388507d2 100644 --- a/src/transformers/tokenization_utils_fast.py +++ b/src/transformers/tokenization_utils_fast.py @@ -418,6 +418,8 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase): overflow_to_sample_mapping += [i] * len(toks["input_ids"]) sanitized_tokens["overflow_to_sample_mapping"] = overflow_to_sample_mapping + for input_ids in sanitized_tokens["input_ids"]: + self._eventual_warn_about_too_long_sequence(input_ids, max_length, verbose) return BatchEncoding(sanitized_tokens, sanitized_encodings, tensor_type=return_tensors) def _encode_plus( @@ -474,6 +476,8 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase): batched_output.encodings, ) + self._eventual_warn_about_too_long_sequence(batched_output["input_ids"], max_length, verbose) + return batched_output def convert_tokens_to_string(self, tokens: List[str]) -> str: diff --git a/tests/test_tokenization_common.py b/tests/test_tokenization_common.py index 1bfd54c3fe..0095b9e243 100644 --- a/tests/test_tokenization_common.py +++ b/tests/test_tokenization_common.py @@ -666,11 +666,28 @@ class TokenizerTesterMixin: self.assertEqual(len(output["input_ids"][0]), model_max_length) # Simple with no truncation - output = tokenizer(seq_1, padding=padding_state, truncation=False) - self.assertNotEqual(len(output["input_ids"]), model_max_length) + # Reset warnings + tokenizer.deprecation_warnings = {} + with self.assertLogs("transformers", level="WARNING") as cm: + output = tokenizer(seq_1, padding=padding_state, truncation=False) + self.assertNotEqual(len(output["input_ids"]), model_max_length) + self.assertEqual(len(cm.records), 1) + self.assertTrue( + 
cm.records[0].message.startswith( + "Token indices sequence length is longer than the specified maximum sequence length for this model" + ) + ) - output = tokenizer([seq_1], padding=padding_state, truncation=False) - self.assertNotEqual(len(output["input_ids"][0]), model_max_length) + tokenizer.deprecation_warnings = {} + with self.assertLogs("transformers", level="WARNING") as cm: + output = tokenizer([seq_1], padding=padding_state, truncation=False) + self.assertNotEqual(len(output["input_ids"][0]), model_max_length) + self.assertEqual(len(cm.records), 1) + self.assertTrue( + cm.records[0].message.startswith( + "Token indices sequence length is longer than the specified maximum sequence length for this model" + ) + ) # Overflowing tokens stride = 2 @@ -770,11 +787,28 @@ class TokenizerTesterMixin: self.assertEqual(len(output["input_ids"][0]), model_max_length) # Simple with no truncation - output = tokenizer(seq_1, seq_2, padding=padding_state, truncation=False) - self.assertNotEqual(len(output["input_ids"]), model_max_length) + # Reset warnings + tokenizer.deprecation_warnings = {} + with self.assertLogs("transformers", level="WARNING") as cm: + output = tokenizer(seq_1, seq_2, padding=padding_state, truncation=False) + self.assertNotEqual(len(output["input_ids"]), model_max_length) + self.assertEqual(len(cm.records), 1) + self.assertTrue( + cm.records[0].message.startswith( + "Token indices sequence length is longer than the specified maximum sequence length for this model" + ) + ) - output = tokenizer([seq_1], [seq_2], padding=padding_state, truncation=False) - self.assertNotEqual(len(output["input_ids"][0]), model_max_length) + tokenizer.deprecation_warnings = {} + with self.assertLogs("transformers", level="WARNING") as cm: + output = tokenizer([seq_1], [seq_2], padding=padding_state, truncation=False) + self.assertNotEqual(len(output["input_ids"][0]), model_max_length) + self.assertEqual(len(cm.records), 1) + self.assertTrue( + 
cm.records[0].message.startswith( + "Token indices sequence length is longer than the specified maximum sequence length for this model" + ) + ) truncated_first_sequence = tokenizer.encode(seq_0, add_special_tokens=False)[:-2] + tokenizer.encode( seq_1, add_special_tokens=False