Show a warning for missing attention masks when pad_token_id is not None (#24510)
* Adding warning messages to BERT for missing attention masks These warning messages when there are pad tokens within the input ids and no attention masks are given. The warning message should only show up once. * Adding warning messages to BERT for missing attention masks These warning messages are shown when the pad_token_id is not None and no attention masks are given. The warning message should only show up once. * Ran fix copies to copy over the changes to some of the other models * Add logger.warning_once.cache_clear() to the test * Shows warning when there are no attention masks and input_ids start/end with pad tokens * Using warning_once() instead and fix indexing in input_ids check --------- Co-authored-by: JB Lau <hckyn@voyager2.local>
This commit is contained in:
@@ -938,6 +938,82 @@ class ModelUtilsTest(TestCasePlus):
|
||||
self.assertIn("were not used when initializing ModelWithHead: ['added_key']", cl.out)
|
||||
self.assertEqual(loading_info["unexpected_keys"], ["added_key"])
|
||||
|
||||
def test_warn_if_padding_and_no_attention_mask(self):
|
||||
logger = logging.get_logger("transformers.modeling_utils")
|
||||
|
||||
with self.subTest("Ensure no warnings when pad_token_id is None."):
|
||||
logger.warning_once.cache_clear()
|
||||
with CaptureLogger(logger) as cl:
|
||||
config_no_pad_token = PretrainedConfig()
|
||||
config_no_pad_token.pad_token_id = None
|
||||
model = ModelWithHead(config_no_pad_token)
|
||||
input_ids = torch.tensor([[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 0, 0]])
|
||||
model.warn_if_padding_and_no_attention_mask(input_ids, attention_mask=None)
|
||||
self.assertNotIn("We strongly recommend passing in an `attention_mask`", cl.out)
|
||||
|
||||
with self.subTest("Ensure no warnings when there is an attention_mask."):
|
||||
logger.warning_once.cache_clear()
|
||||
with CaptureLogger(logger) as cl:
|
||||
config = PretrainedConfig()
|
||||
config.pad_token_id = 0
|
||||
model = ModelWithHead(config)
|
||||
input_ids = torch.tensor([[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 0, 0]])
|
||||
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0]])
|
||||
model.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
|
||||
self.assertNotIn("We strongly recommend passing in an `attention_mask`", cl.out)
|
||||
|
||||
with self.subTest("Ensure no warnings when there are no pad_token_ids in the input_ids."):
|
||||
logger.warning_once.cache_clear()
|
||||
with CaptureLogger(logger) as cl:
|
||||
config = PretrainedConfig()
|
||||
config.pad_token_id = 0
|
||||
model = ModelWithHead(config)
|
||||
input_ids = torch.tensor([[1, 345, 232, 328, 740, 140, 1695, 69, 6078, 2341, 25]])
|
||||
model.warn_if_padding_and_no_attention_mask(input_ids, attention_mask=None)
|
||||
self.assertNotIn("We strongly recommend passing in an `attention_mask`", cl.out)
|
||||
|
||||
with self.subTest("Ensure a warning is shown when the input_ids start with a pad_token_id."):
|
||||
logger.warning_once.cache_clear()
|
||||
with CaptureLogger(logger) as cl:
|
||||
config = PretrainedConfig()
|
||||
config.pad_token_id = 0
|
||||
model = ModelWithHead(config)
|
||||
input_ids = torch.tensor([[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 432, 5232]])
|
||||
model.warn_if_padding_and_no_attention_mask(input_ids, attention_mask=None)
|
||||
self.assertIn("We strongly recommend passing in an `attention_mask`", cl.out)
|
||||
|
||||
with self.subTest("Ensure a warning is shown when the input_ids end with a pad_token_id."):
|
||||
logger.warning_once.cache_clear()
|
||||
with CaptureLogger(logger) as cl:
|
||||
config = PretrainedConfig()
|
||||
config.pad_token_id = 0
|
||||
model = ModelWithHead(config)
|
||||
input_ids = torch.tensor([[432, 345, 232, 328, 740, 140, 1695, 69, 6078, 0, 0]])
|
||||
model.warn_if_padding_and_no_attention_mask(input_ids, attention_mask=None)
|
||||
self.assertIn("We strongly recommend passing in an `attention_mask`", cl.out)
|
||||
|
||||
with self.subTest("Ensure that the warning is shown at most once."):
|
||||
logger.warning_once.cache_clear()
|
||||
with CaptureLogger(logger) as cl:
|
||||
config = PretrainedConfig()
|
||||
config.pad_token_id = 0
|
||||
model = ModelWithHead(config)
|
||||
input_ids = torch.tensor([[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 0, 0]])
|
||||
model.warn_if_padding_and_no_attention_mask(input_ids, attention_mask=None)
|
||||
model.warn_if_padding_and_no_attention_mask(input_ids, attention_mask=None)
|
||||
self.assertEqual(cl.out.count("We strongly recommend passing in an `attention_mask`"), 1)
|
||||
|
||||
with self.subTest("Ensure a different warning is shown when the pad_token_id is equal to the bos_token_id."):
|
||||
logger.warning_once.cache_clear()
|
||||
with CaptureLogger(logger) as cl:
|
||||
config = PretrainedConfig()
|
||||
config.pad_token_id = 0
|
||||
config.bos_token_id = config.pad_token_id
|
||||
model = ModelWithHead(config)
|
||||
input_ids = torch.tensor([[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 0, 0]])
|
||||
model.warn_if_padding_and_no_attention_mask(input_ids, attention_mask=None)
|
||||
self.assertIn("You may ignore this warning if your `pad_token_id`", cl.out)
|
||||
|
||||
@require_torch_gpu
|
||||
@slow
|
||||
def test_pretrained_low_mem_new_config(self):
|
||||
|
||||
Reference in New Issue
Block a user