Show a warning for missing attention masks when pad_token_id is not None (#24510)

* Adding warning messages to BERT for missing attention masks

These warning messages are shown when there are pad tokens within the input ids and
no attention masks are given. The warning message should only show up once.

* Adding warning messages to BERT for missing attention masks

These warning messages are shown when the pad_token_id is not None
and no attention masks are given. The warning message should only
show up once.

* Ran `make fix-copies` to copy over the changes to some of the other models

* Add logger.warning_once.cache_clear() to the test

* Shows warning when there are no attention masks and input_ids start/end with pad tokens

* Using warning_once() instead and fix indexing in input_ids check

---------

Co-authored-by: JB Lau <hckyn@voyager2.local>
This commit is contained in:
JB (Don)
2023-06-30 21:19:39 +09:00
committed by GitHub
parent fd8dcd0953
commit 78a2b19fc8
12 changed files with 140 additions and 1 deletions

View File

@@ -18,7 +18,7 @@ import unittest
from transformers import BertConfig, is_torch_available
from transformers.models.auto import get_values
from transformers.testing_utils import require_torch, require_torch_gpu, slow, torch_device
from transformers.testing_utils import CaptureLogger, require_torch, require_torch_gpu, slow, torch_device
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
@@ -40,6 +40,7 @@ if is_torch_available():
BertForTokenClassification,
BertLMHeadModel,
BertModel,
logging,
)
from transformers.models.bert.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_LIST
@@ -567,6 +568,29 @@ class BertModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_token_classification(*config_and_inputs)
def test_for_warning_if_padding_and_no_attention_mask(self):
    """Check that calling the model without an attention mask logs a warning.

    A pad token is written into `input_ids`, the model is called with
    `attention_mask=None`, and the captured output of the
    "transformers.modeling_utils" logger must contain the recommendation to
    pass an `attention_mask`.
    """
    (
        config,
        input_ids,
        token_type_ids,
        input_mask,
        sequence_labels,
        token_labels,
        choice_labels,
    ) = self.model_tester.prepare_config_and_inputs()

    # Set pad tokens in the input_ids
    input_ids[0, 0] = config.pad_token_id

    # Check for warnings if the attention_mask is missing.
    logger = logging.get_logger("transformers.modeling_utils")
    # `warning_once` only emits a given message the first time it is requested;
    # clear its cache so a warning already triggered by an earlier test in the
    # same process cannot leave the captured output empty here.
    logger.warning_once.cache_clear()

    with CaptureLogger(logger) as cl:
        model = BertModel(config=config)
        model.to(torch_device)
        model.eval()
        model(input_ids, attention_mask=None, token_type_ids=token_type_ids)
    self.assertIn("We strongly recommend passing in an `attention_mask`", cl.out)
@slow
def test_model_from_pretrained(self):
for model_name in BERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: