Show a warning for missing attention masks when pad_token_id is not None (#24510)

* Adding warning messages to BERT for missing attention masks

These warning messages when there are pad tokens within the input ids and
no attention masks are given. The warning message should only show up once.

* Adding warning messages to BERT for missing attention masks

These warning messages are shown when the pad_token_id is not None
and no attention masks are given. The warning message should only
show up once.

* Ran fix copies to copy over the changes to some of the other models

* Add logger.warning_once.cache_clear() to the test

* Shows warning when there are no attention masks and input_ids start/end with pad tokens

* Using warning_once() instead and fix indexing in input_ids check

---------

Co-authored-by: JB Lau <hckyn@voyager2.local>
This commit is contained in:
JB (Don)
2023-06-30 21:19:39 +09:00
committed by GitHub
parent fd8dcd0953
commit 78a2b19fc8
12 changed files with 140 additions and 1 deletions

View File

@@ -938,6 +938,82 @@ class ModelUtilsTest(TestCasePlus):
self.assertIn("were not used when initializing ModelWithHead: ['added_key']", cl.out)
self.assertEqual(loading_info["unexpected_keys"], ["added_key"])
def test_warn_if_padding_and_no_attention_mask(self):
logger = logging.get_logger("transformers.modeling_utils")
with self.subTest("Ensure no warnings when pad_token_id is None."):
logger.warning_once.cache_clear()
with CaptureLogger(logger) as cl:
config_no_pad_token = PretrainedConfig()
config_no_pad_token.pad_token_id = None
model = ModelWithHead(config_no_pad_token)
input_ids = torch.tensor([[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 0, 0]])
model.warn_if_padding_and_no_attention_mask(input_ids, attention_mask=None)
self.assertNotIn("We strongly recommend passing in an `attention_mask`", cl.out)
with self.subTest("Ensure no warnings when there is an attention_mask."):
logger.warning_once.cache_clear()
with CaptureLogger(logger) as cl:
config = PretrainedConfig()
config.pad_token_id = 0
model = ModelWithHead(config)
input_ids = torch.tensor([[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 0, 0]])
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0]])
model.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
self.assertNotIn("We strongly recommend passing in an `attention_mask`", cl.out)
with self.subTest("Ensure no warnings when there are no pad_token_ids in the input_ids."):
logger.warning_once.cache_clear()
with CaptureLogger(logger) as cl:
config = PretrainedConfig()
config.pad_token_id = 0
model = ModelWithHead(config)
input_ids = torch.tensor([[1, 345, 232, 328, 740, 140, 1695, 69, 6078, 2341, 25]])
model.warn_if_padding_and_no_attention_mask(input_ids, attention_mask=None)
self.assertNotIn("We strongly recommend passing in an `attention_mask`", cl.out)
with self.subTest("Ensure a warning is shown when the input_ids start with a pad_token_id."):
logger.warning_once.cache_clear()
with CaptureLogger(logger) as cl:
config = PretrainedConfig()
config.pad_token_id = 0
model = ModelWithHead(config)
input_ids = torch.tensor([[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 432, 5232]])
model.warn_if_padding_and_no_attention_mask(input_ids, attention_mask=None)
self.assertIn("We strongly recommend passing in an `attention_mask`", cl.out)
with self.subTest("Ensure a warning is shown when the input_ids end with a pad_token_id."):
logger.warning_once.cache_clear()
with CaptureLogger(logger) as cl:
config = PretrainedConfig()
config.pad_token_id = 0
model = ModelWithHead(config)
input_ids = torch.tensor([[432, 345, 232, 328, 740, 140, 1695, 69, 6078, 0, 0]])
model.warn_if_padding_and_no_attention_mask(input_ids, attention_mask=None)
self.assertIn("We strongly recommend passing in an `attention_mask`", cl.out)
with self.subTest("Ensure that the warning is shown at most once."):
logger.warning_once.cache_clear()
with CaptureLogger(logger) as cl:
config = PretrainedConfig()
config.pad_token_id = 0
model = ModelWithHead(config)
input_ids = torch.tensor([[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 0, 0]])
model.warn_if_padding_and_no_attention_mask(input_ids, attention_mask=None)
model.warn_if_padding_and_no_attention_mask(input_ids, attention_mask=None)
self.assertEqual(cl.out.count("We strongly recommend passing in an `attention_mask`"), 1)
with self.subTest("Ensure a different warning is shown when the pad_token_id is equal to the bos_token_id."):
logger.warning_once.cache_clear()
with CaptureLogger(logger) as cl:
config = PretrainedConfig()
config.pad_token_id = 0
config.bos_token_id = config.pad_token_id
model = ModelWithHead(config)
input_ids = torch.tensor([[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 0, 0]])
model.warn_if_padding_and_no_attention_mask(input_ids, attention_mask=None)
self.assertIn("You may ignore this warning if your `pad_token_id`", cl.out)
@require_torch_gpu
@slow
def test_pretrained_low_mem_new_config(self):