Warning about too long input for fast tokenizers too (#8799)
* Warning about too long input for fast tokenizers too. If truncation is not set in tokenizers, but the tokenization is too long for the model (`model_max_length`), we used to trigger a warning that the input would probably fail (which it most likely will). This PR re-enables the warning for fast tokenizers too and uses common code for the trigger to make sure it's consistent across slow and fast tokenizers. * Checking for pairs of inputs too. * Making the function private and adding its doc. * Remove formatting ?? in odd place. * Missed uppercase.
This commit is contained in:
@@ -666,11 +666,28 @@ class TokenizerTesterMixin:
|
||||
self.assertEqual(len(output["input_ids"][0]), model_max_length)
|
||||
|
||||
# Simple with no truncation
|
||||
output = tokenizer(seq_1, padding=padding_state, truncation=False)
|
||||
self.assertNotEqual(len(output["input_ids"]), model_max_length)
|
||||
# Reset warnings
|
||||
tokenizer.deprecation_warnings = {}
|
||||
with self.assertLogs("transformers", level="WARNING") as cm:
|
||||
output = tokenizer(seq_1, padding=padding_state, truncation=False)
|
||||
self.assertNotEqual(len(output["input_ids"]), model_max_length)
|
||||
self.assertEqual(len(cm.records), 1)
|
||||
self.assertTrue(
|
||||
cm.records[0].message.startswith(
|
||||
"Token indices sequence length is longer than the specified maximum sequence length for this model"
|
||||
)
|
||||
)
|
||||
|
||||
output = tokenizer([seq_1], padding=padding_state, truncation=False)
|
||||
self.assertNotEqual(len(output["input_ids"][0]), model_max_length)
|
||||
tokenizer.deprecation_warnings = {}
|
||||
with self.assertLogs("transformers", level="WARNING") as cm:
|
||||
output = tokenizer([seq_1], padding=padding_state, truncation=False)
|
||||
self.assertNotEqual(len(output["input_ids"][0]), model_max_length)
|
||||
self.assertEqual(len(cm.records), 1)
|
||||
self.assertTrue(
|
||||
cm.records[0].message.startswith(
|
||||
"Token indices sequence length is longer than the specified maximum sequence length for this model"
|
||||
)
|
||||
)
|
||||
|
||||
# Overflowing tokens
|
||||
stride = 2
|
||||
@@ -770,11 +787,28 @@ class TokenizerTesterMixin:
|
||||
self.assertEqual(len(output["input_ids"][0]), model_max_length)
|
||||
|
||||
# Simple with no truncation
|
||||
output = tokenizer(seq_1, seq_2, padding=padding_state, truncation=False)
|
||||
self.assertNotEqual(len(output["input_ids"]), model_max_length)
|
||||
# Reset warnings
|
||||
tokenizer.deprecation_warnings = {}
|
||||
with self.assertLogs("transformers", level="WARNING") as cm:
|
||||
output = tokenizer(seq_1, seq_2, padding=padding_state, truncation=False)
|
||||
self.assertNotEqual(len(output["input_ids"]), model_max_length)
|
||||
self.assertEqual(len(cm.records), 1)
|
||||
self.assertTrue(
|
||||
cm.records[0].message.startswith(
|
||||
"Token indices sequence length is longer than the specified maximum sequence length for this model"
|
||||
)
|
||||
)
|
||||
|
||||
output = tokenizer([seq_1], [seq_2], padding=padding_state, truncation=False)
|
||||
self.assertNotEqual(len(output["input_ids"][0]), model_max_length)
|
||||
tokenizer.deprecation_warnings = {}
|
||||
with self.assertLogs("transformers", level="WARNING") as cm:
|
||||
output = tokenizer([seq_1], [seq_2], padding=padding_state, truncation=False)
|
||||
self.assertNotEqual(len(output["input_ids"][0]), model_max_length)
|
||||
self.assertEqual(len(cm.records), 1)
|
||||
self.assertTrue(
|
||||
cm.records[0].message.startswith(
|
||||
"Token indices sequence length is longer than the specified maximum sequence length for this model"
|
||||
)
|
||||
)
|
||||
|
||||
truncated_first_sequence = tokenizer.encode(seq_0, add_special_tokens=False)[:-2] + tokenizer.encode(
|
||||
seq_1, add_special_tokens=False
|
||||
|
||||
Reference in New Issue
Block a user