Warning about too long input for fast tokenizers too (#8799)

* Warning about too long input for fast tokenizers too

If truncation is not set in tokenizers, but the tokenization is too long
for the model (`model_max_length`), we used to trigger a warning that

The input would probably fail (which it most likely will).

This PR re-enables the warning for fast tokenizers too and uses common
code for the trigger to make sure it's consistent across.

* Checking for pair of inputs too.

* Making the function private and adding it's doc.

* Remove formatting ?? in odd place.

* Missed uppercase.
This commit is contained in:
Nicolas Patry
2020-12-02 16:18:28 +01:00
committed by GitHub
parent f6b44e6190
commit a8c3f9aa76
3 changed files with 67 additions and 16 deletions

View File

@@ -666,11 +666,28 @@ class TokenizerTesterMixin:
self.assertEqual(len(output["input_ids"][0]), model_max_length)
# Simple with no truncation
output = tokenizer(seq_1, padding=padding_state, truncation=False)
self.assertNotEqual(len(output["input_ids"]), model_max_length)
# Reset warnings
tokenizer.deprecation_warnings = {}
with self.assertLogs("transformers", level="WARNING") as cm:
output = tokenizer(seq_1, padding=padding_state, truncation=False)
self.assertNotEqual(len(output["input_ids"]), model_max_length)
self.assertEqual(len(cm.records), 1)
self.assertTrue(
cm.records[0].message.startswith(
"Token indices sequence length is longer than the specified maximum sequence length for this model"
)
)
output = tokenizer([seq_1], padding=padding_state, truncation=False)
self.assertNotEqual(len(output["input_ids"][0]), model_max_length)
tokenizer.deprecation_warnings = {}
with self.assertLogs("transformers", level="WARNING") as cm:
output = tokenizer([seq_1], padding=padding_state, truncation=False)
self.assertNotEqual(len(output["input_ids"][0]), model_max_length)
self.assertEqual(len(cm.records), 1)
self.assertTrue(
cm.records[0].message.startswith(
"Token indices sequence length is longer than the specified maximum sequence length for this model"
)
)
# Overflowing tokens
stride = 2
@@ -770,11 +787,28 @@ class TokenizerTesterMixin:
self.assertEqual(len(output["input_ids"][0]), model_max_length)
# Simple with no truncation
output = tokenizer(seq_1, seq_2, padding=padding_state, truncation=False)
self.assertNotEqual(len(output["input_ids"]), model_max_length)
# Reset warnings
tokenizer.deprecation_warnings = {}
with self.assertLogs("transformers", level="WARNING") as cm:
output = tokenizer(seq_1, seq_2, padding=padding_state, truncation=False)
self.assertNotEqual(len(output["input_ids"]), model_max_length)
self.assertEqual(len(cm.records), 1)
self.assertTrue(
cm.records[0].message.startswith(
"Token indices sequence length is longer than the specified maximum sequence length for this model"
)
)
output = tokenizer([seq_1], [seq_2], padding=padding_state, truncation=False)
self.assertNotEqual(len(output["input_ids"][0]), model_max_length)
tokenizer.deprecation_warnings = {}
with self.assertLogs("transformers", level="WARNING") as cm:
output = tokenizer([seq_1], [seq_2], padding=padding_state, truncation=False)
self.assertNotEqual(len(output["input_ids"][0]), model_max_length)
self.assertEqual(len(cm.records), 1)
self.assertTrue(
cm.records[0].message.startswith(
"Token indices sequence length is longer than the specified maximum sequence length for this model"
)
)
truncated_first_sequence = tokenizer.encode(seq_0, add_special_tokens=False)[:-2] + tokenizer.encode(
seq_1, add_special_tokens=False