Use LoggingLevel context manager in 3 tests (#28575)
* inside with LoggingLevel * remove is_flaky --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -43,8 +43,8 @@ from transformers.testing_utils import (
|
||||
TOKEN,
|
||||
USER,
|
||||
CaptureLogger,
|
||||
LoggingLevel,
|
||||
TestCasePlus,
|
||||
is_flaky,
|
||||
is_staging_test,
|
||||
require_accelerate,
|
||||
require_flax,
|
||||
@@ -290,16 +290,14 @@ class ModelUtilsTest(TestCasePlus):
|
||||
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
@is_flaky(
|
||||
description="Capturing logs is flaky: https://app.circleci.com/pipelines/github/huggingface/transformers/81004/workflows/4919e5c9-0ea2-457b-ad4f-65371f79e277/jobs/1038999"
|
||||
)
|
||||
def test_model_from_pretrained_with_different_pretrained_model_name(self):
|
||||
model = T5ForConditionalGeneration.from_pretrained(TINY_T5)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
logger = logging.get_logger("transformers.configuration_utils")
|
||||
with CaptureLogger(logger) as cl:
|
||||
BertModel.from_pretrained(TINY_T5)
|
||||
with LoggingLevel(logging.WARNING):
|
||||
with CaptureLogger(logger) as cl:
|
||||
BertModel.from_pretrained(TINY_T5)
|
||||
self.assertTrue("You are using a model of type t5 to instantiate a model of type bert" in cl.out)
|
||||
|
||||
def test_model_from_config_torch_dtype(self):
|
||||
@@ -1024,9 +1022,6 @@ class ModelUtilsTest(TestCasePlus):
|
||||
# Should only complain about the missing bias
|
||||
self.assertListEqual(load_info["missing_keys"], ["decoder.bias"])
|
||||
|
||||
@is_flaky(
|
||||
description="Capturing logs is flaky: https://app.circleci.com/pipelines/github/huggingface/transformers/81004/workflows/4919e5c9-0ea2-457b-ad4f-65371f79e277/jobs/1038999"
|
||||
)
|
||||
def test_unexpected_keys_warnings(self):
|
||||
model = ModelWithHead(PretrainedConfig())
|
||||
logger = logging.get_logger("transformers.modeling_utils")
|
||||
@@ -1034,8 +1029,9 @@ class ModelUtilsTest(TestCasePlus):
|
||||
model.save_pretrained(tmp_dir)
|
||||
|
||||
# Loading the model with a new class, we don't get a warning for unexpected weights, just an info
|
||||
with CaptureLogger(logger) as cl:
|
||||
_, loading_info = BaseModel.from_pretrained(tmp_dir, output_loading_info=True)
|
||||
with LoggingLevel(logging.WARNING):
|
||||
with CaptureLogger(logger) as cl:
|
||||
_, loading_info = BaseModel.from_pretrained(tmp_dir, output_loading_info=True)
|
||||
self.assertNotIn("were not used when initializing ModelWithHead", cl.out)
|
||||
self.assertEqual(
|
||||
set(loading_info["unexpected_keys"]),
|
||||
@@ -1046,8 +1042,9 @@ class ModelUtilsTest(TestCasePlus):
|
||||
state_dict = model.state_dict()
|
||||
state_dict["added_key"] = copy.deepcopy(state_dict["linear.weight"])
|
||||
safe_save_file(state_dict, os.path.join(tmp_dir, SAFE_WEIGHTS_NAME), metadata={"format": "pt"})
|
||||
with CaptureLogger(logger) as cl:
|
||||
_, loading_info = ModelWithHead.from_pretrained(tmp_dir, output_loading_info=True)
|
||||
with LoggingLevel(logging.WARNING):
|
||||
with CaptureLogger(logger) as cl:
|
||||
_, loading_info = ModelWithHead.from_pretrained(tmp_dir, output_loading_info=True)
|
||||
self.assertIn("were not used when initializing ModelWithHead: ['added_key']", cl.out)
|
||||
self.assertEqual(loading_info["unexpected_keys"], ["added_key"])
|
||||
|
||||
@@ -1056,75 +1053,82 @@ class ModelUtilsTest(TestCasePlus):
|
||||
|
||||
with self.subTest("Ensure no warnings when pad_token_id is None."):
|
||||
logger.warning_once.cache_clear()
|
||||
with CaptureLogger(logger) as cl:
|
||||
config_no_pad_token = PretrainedConfig()
|
||||
config_no_pad_token.pad_token_id = None
|
||||
model = ModelWithHead(config_no_pad_token)
|
||||
input_ids = torch.tensor([[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 0, 0]])
|
||||
model.warn_if_padding_and_no_attention_mask(input_ids, attention_mask=None)
|
||||
with LoggingLevel(logging.WARNING):
|
||||
with CaptureLogger(logger) as cl:
|
||||
config_no_pad_token = PretrainedConfig()
|
||||
config_no_pad_token.pad_token_id = None
|
||||
model = ModelWithHead(config_no_pad_token)
|
||||
input_ids = torch.tensor([[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 0, 0]])
|
||||
model.warn_if_padding_and_no_attention_mask(input_ids, attention_mask=None)
|
||||
self.assertNotIn("We strongly recommend passing in an `attention_mask`", cl.out)
|
||||
|
||||
with self.subTest("Ensure no warnings when there is an attention_mask."):
|
||||
logger.warning_once.cache_clear()
|
||||
with CaptureLogger(logger) as cl:
|
||||
config = PretrainedConfig()
|
||||
config.pad_token_id = 0
|
||||
model = ModelWithHead(config)
|
||||
input_ids = torch.tensor([[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 0, 0]])
|
||||
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0]])
|
||||
model.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
|
||||
with LoggingLevel(logging.WARNING):
|
||||
with CaptureLogger(logger) as cl:
|
||||
config = PretrainedConfig()
|
||||
config.pad_token_id = 0
|
||||
model = ModelWithHead(config)
|
||||
input_ids = torch.tensor([[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 0, 0]])
|
||||
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0]])
|
||||
model.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
|
||||
self.assertNotIn("We strongly recommend passing in an `attention_mask`", cl.out)
|
||||
|
||||
with self.subTest("Ensure no warnings when there are no pad_token_ids in the input_ids."):
|
||||
logger.warning_once.cache_clear()
|
||||
with CaptureLogger(logger) as cl:
|
||||
config = PretrainedConfig()
|
||||
config.pad_token_id = 0
|
||||
model = ModelWithHead(config)
|
||||
input_ids = torch.tensor([[1, 345, 232, 328, 740, 140, 1695, 69, 6078, 2341, 25]])
|
||||
model.warn_if_padding_and_no_attention_mask(input_ids, attention_mask=None)
|
||||
with LoggingLevel(logging.WARNING):
|
||||
with CaptureLogger(logger) as cl:
|
||||
config = PretrainedConfig()
|
||||
config.pad_token_id = 0
|
||||
model = ModelWithHead(config)
|
||||
input_ids = torch.tensor([[1, 345, 232, 328, 740, 140, 1695, 69, 6078, 2341, 25]])
|
||||
model.warn_if_padding_and_no_attention_mask(input_ids, attention_mask=None)
|
||||
self.assertNotIn("We strongly recommend passing in an `attention_mask`", cl.out)
|
||||
|
||||
with self.subTest("Ensure a warning is shown when the input_ids start with a pad_token_id."):
|
||||
logger.warning_once.cache_clear()
|
||||
with CaptureLogger(logger) as cl:
|
||||
config = PretrainedConfig()
|
||||
config.pad_token_id = 0
|
||||
model = ModelWithHead(config)
|
||||
input_ids = torch.tensor([[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 432, 5232]])
|
||||
model.warn_if_padding_and_no_attention_mask(input_ids, attention_mask=None)
|
||||
with LoggingLevel(logging.WARNING):
|
||||
with CaptureLogger(logger) as cl:
|
||||
config = PretrainedConfig()
|
||||
config.pad_token_id = 0
|
||||
model = ModelWithHead(config)
|
||||
input_ids = torch.tensor([[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 432, 5232]])
|
||||
model.warn_if_padding_and_no_attention_mask(input_ids, attention_mask=None)
|
||||
self.assertIn("We strongly recommend passing in an `attention_mask`", cl.out)
|
||||
|
||||
with self.subTest("Ensure a warning is shown when the input_ids end with a pad_token_id."):
|
||||
logger.warning_once.cache_clear()
|
||||
with CaptureLogger(logger) as cl:
|
||||
config = PretrainedConfig()
|
||||
config.pad_token_id = 0
|
||||
model = ModelWithHead(config)
|
||||
input_ids = torch.tensor([[432, 345, 232, 328, 740, 140, 1695, 69, 6078, 0, 0]])
|
||||
model.warn_if_padding_and_no_attention_mask(input_ids, attention_mask=None)
|
||||
with LoggingLevel(logging.WARNING):
|
||||
with CaptureLogger(logger) as cl:
|
||||
config = PretrainedConfig()
|
||||
config.pad_token_id = 0
|
||||
model = ModelWithHead(config)
|
||||
input_ids = torch.tensor([[432, 345, 232, 328, 740, 140, 1695, 69, 6078, 0, 0]])
|
||||
model.warn_if_padding_and_no_attention_mask(input_ids, attention_mask=None)
|
||||
self.assertIn("We strongly recommend passing in an `attention_mask`", cl.out)
|
||||
|
||||
with self.subTest("Ensure that the warning is shown at most once."):
|
||||
logger.warning_once.cache_clear()
|
||||
with CaptureLogger(logger) as cl:
|
||||
config = PretrainedConfig()
|
||||
config.pad_token_id = 0
|
||||
model = ModelWithHead(config)
|
||||
input_ids = torch.tensor([[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 0, 0]])
|
||||
model.warn_if_padding_and_no_attention_mask(input_ids, attention_mask=None)
|
||||
model.warn_if_padding_and_no_attention_mask(input_ids, attention_mask=None)
|
||||
with LoggingLevel(logging.WARNING):
|
||||
with CaptureLogger(logger) as cl:
|
||||
config = PretrainedConfig()
|
||||
config.pad_token_id = 0
|
||||
model = ModelWithHead(config)
|
||||
input_ids = torch.tensor([[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 0, 0]])
|
||||
model.warn_if_padding_and_no_attention_mask(input_ids, attention_mask=None)
|
||||
model.warn_if_padding_and_no_attention_mask(input_ids, attention_mask=None)
|
||||
self.assertEqual(cl.out.count("We strongly recommend passing in an `attention_mask`"), 1)
|
||||
|
||||
with self.subTest("Ensure a different warning is shown when the pad_token_id is equal to the bos_token_id."):
|
||||
logger.warning_once.cache_clear()
|
||||
with CaptureLogger(logger) as cl:
|
||||
config = PretrainedConfig()
|
||||
config.pad_token_id = 0
|
||||
config.bos_token_id = config.pad_token_id
|
||||
model = ModelWithHead(config)
|
||||
input_ids = torch.tensor([[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 0, 0]])
|
||||
model.warn_if_padding_and_no_attention_mask(input_ids, attention_mask=None)
|
||||
with LoggingLevel(logging.WARNING):
|
||||
with CaptureLogger(logger) as cl:
|
||||
config = PretrainedConfig()
|
||||
config.pad_token_id = 0
|
||||
config.bos_token_id = config.pad_token_id
|
||||
model = ModelWithHead(config)
|
||||
input_ids = torch.tensor([[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 0, 0]])
|
||||
model.warn_if_padding_and_no_attention_mask(input_ids, attention_mask=None)
|
||||
self.assertIn("You may ignore this warning if your `pad_token_id`", cl.out)
|
||||
|
||||
if not is_torchdynamo_available():
|
||||
|
||||
Reference in New Issue
Block a user