Use LoggingLevel context manager in 3 tests (#28575)

* inside with LoggingLevel

* remove is_flaky

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar
2024-01-18 14:41:25 +01:00
committed by GitHub
parent d2cdefb9ec
commit 0754217c82

View File

@@ -43,8 +43,8 @@ from transformers.testing_utils import (
TOKEN, TOKEN,
USER, USER,
CaptureLogger, CaptureLogger,
LoggingLevel,
TestCasePlus, TestCasePlus,
is_flaky,
is_staging_test, is_staging_test,
require_accelerate, require_accelerate,
require_flax, require_flax,
@@ -290,14 +290,12 @@ class ModelUtilsTest(TestCasePlus):
self.assertIsNotNone(model) self.assertIsNotNone(model)
@is_flaky(
description="Capturing logs is flaky: https://app.circleci.com/pipelines/github/huggingface/transformers/81004/workflows/4919e5c9-0ea2-457b-ad4f-65371f79e277/jobs/1038999"
)
def test_model_from_pretrained_with_different_pretrained_model_name(self): def test_model_from_pretrained_with_different_pretrained_model_name(self):
model = T5ForConditionalGeneration.from_pretrained(TINY_T5) model = T5ForConditionalGeneration.from_pretrained(TINY_T5)
self.assertIsNotNone(model) self.assertIsNotNone(model)
logger = logging.get_logger("transformers.configuration_utils") logger = logging.get_logger("transformers.configuration_utils")
with LoggingLevel(logging.WARNING):
with CaptureLogger(logger) as cl: with CaptureLogger(logger) as cl:
BertModel.from_pretrained(TINY_T5) BertModel.from_pretrained(TINY_T5)
self.assertTrue("You are using a model of type t5 to instantiate a model of type bert" in cl.out) self.assertTrue("You are using a model of type t5 to instantiate a model of type bert" in cl.out)
@@ -1024,9 +1022,6 @@ class ModelUtilsTest(TestCasePlus):
# Should only complain about the missing bias # Should only complain about the missing bias
self.assertListEqual(load_info["missing_keys"], ["decoder.bias"]) self.assertListEqual(load_info["missing_keys"], ["decoder.bias"])
@is_flaky(
description="Capturing logs is flaky: https://app.circleci.com/pipelines/github/huggingface/transformers/81004/workflows/4919e5c9-0ea2-457b-ad4f-65371f79e277/jobs/1038999"
)
def test_unexpected_keys_warnings(self): def test_unexpected_keys_warnings(self):
model = ModelWithHead(PretrainedConfig()) model = ModelWithHead(PretrainedConfig())
logger = logging.get_logger("transformers.modeling_utils") logger = logging.get_logger("transformers.modeling_utils")
@@ -1034,6 +1029,7 @@ class ModelUtilsTest(TestCasePlus):
model.save_pretrained(tmp_dir) model.save_pretrained(tmp_dir)
# Loading the model with a new class, we don't get a warning for unexpected weights, just an info # Loading the model with a new class, we don't get a warning for unexpected weights, just an info
with LoggingLevel(logging.WARNING):
with CaptureLogger(logger) as cl: with CaptureLogger(logger) as cl:
_, loading_info = BaseModel.from_pretrained(tmp_dir, output_loading_info=True) _, loading_info = BaseModel.from_pretrained(tmp_dir, output_loading_info=True)
self.assertNotIn("were not used when initializing ModelWithHead", cl.out) self.assertNotIn("were not used when initializing ModelWithHead", cl.out)
@@ -1046,6 +1042,7 @@ class ModelUtilsTest(TestCasePlus):
state_dict = model.state_dict() state_dict = model.state_dict()
state_dict["added_key"] = copy.deepcopy(state_dict["linear.weight"]) state_dict["added_key"] = copy.deepcopy(state_dict["linear.weight"])
safe_save_file(state_dict, os.path.join(tmp_dir, SAFE_WEIGHTS_NAME), metadata={"format": "pt"}) safe_save_file(state_dict, os.path.join(tmp_dir, SAFE_WEIGHTS_NAME), metadata={"format": "pt"})
with LoggingLevel(logging.WARNING):
with CaptureLogger(logger) as cl: with CaptureLogger(logger) as cl:
_, loading_info = ModelWithHead.from_pretrained(tmp_dir, output_loading_info=True) _, loading_info = ModelWithHead.from_pretrained(tmp_dir, output_loading_info=True)
self.assertIn("were not used when initializing ModelWithHead: ['added_key']", cl.out) self.assertIn("were not used when initializing ModelWithHead: ['added_key']", cl.out)
@@ -1056,6 +1053,7 @@ class ModelUtilsTest(TestCasePlus):
with self.subTest("Ensure no warnings when pad_token_id is None."): with self.subTest("Ensure no warnings when pad_token_id is None."):
logger.warning_once.cache_clear() logger.warning_once.cache_clear()
with LoggingLevel(logging.WARNING):
with CaptureLogger(logger) as cl: with CaptureLogger(logger) as cl:
config_no_pad_token = PretrainedConfig() config_no_pad_token = PretrainedConfig()
config_no_pad_token.pad_token_id = None config_no_pad_token.pad_token_id = None
@@ -1066,6 +1064,7 @@ class ModelUtilsTest(TestCasePlus):
with self.subTest("Ensure no warnings when there is an attention_mask."): with self.subTest("Ensure no warnings when there is an attention_mask."):
logger.warning_once.cache_clear() logger.warning_once.cache_clear()
with LoggingLevel(logging.WARNING):
with CaptureLogger(logger) as cl: with CaptureLogger(logger) as cl:
config = PretrainedConfig() config = PretrainedConfig()
config.pad_token_id = 0 config.pad_token_id = 0
@@ -1077,6 +1076,7 @@ class ModelUtilsTest(TestCasePlus):
with self.subTest("Ensure no warnings when there are no pad_token_ids in the input_ids."): with self.subTest("Ensure no warnings when there are no pad_token_ids in the input_ids."):
logger.warning_once.cache_clear() logger.warning_once.cache_clear()
with LoggingLevel(logging.WARNING):
with CaptureLogger(logger) as cl: with CaptureLogger(logger) as cl:
config = PretrainedConfig() config = PretrainedConfig()
config.pad_token_id = 0 config.pad_token_id = 0
@@ -1087,6 +1087,7 @@ class ModelUtilsTest(TestCasePlus):
with self.subTest("Ensure a warning is shown when the input_ids start with a pad_token_id."): with self.subTest("Ensure a warning is shown when the input_ids start with a pad_token_id."):
logger.warning_once.cache_clear() logger.warning_once.cache_clear()
with LoggingLevel(logging.WARNING):
with CaptureLogger(logger) as cl: with CaptureLogger(logger) as cl:
config = PretrainedConfig() config = PretrainedConfig()
config.pad_token_id = 0 config.pad_token_id = 0
@@ -1097,6 +1098,7 @@ class ModelUtilsTest(TestCasePlus):
with self.subTest("Ensure a warning is shown when the input_ids end with a pad_token_id."): with self.subTest("Ensure a warning is shown when the input_ids end with a pad_token_id."):
logger.warning_once.cache_clear() logger.warning_once.cache_clear()
with LoggingLevel(logging.WARNING):
with CaptureLogger(logger) as cl: with CaptureLogger(logger) as cl:
config = PretrainedConfig() config = PretrainedConfig()
config.pad_token_id = 0 config.pad_token_id = 0
@@ -1107,6 +1109,7 @@ class ModelUtilsTest(TestCasePlus):
with self.subTest("Ensure that the warning is shown at most once."): with self.subTest("Ensure that the warning is shown at most once."):
logger.warning_once.cache_clear() logger.warning_once.cache_clear()
with LoggingLevel(logging.WARNING):
with CaptureLogger(logger) as cl: with CaptureLogger(logger) as cl:
config = PretrainedConfig() config = PretrainedConfig()
config.pad_token_id = 0 config.pad_token_id = 0
@@ -1118,6 +1121,7 @@ class ModelUtilsTest(TestCasePlus):
with self.subTest("Ensure a different warning is shown when the pad_token_id is equal to the bos_token_id."): with self.subTest("Ensure a different warning is shown when the pad_token_id is equal to the bos_token_id."):
logger.warning_once.cache_clear() logger.warning_once.cache_clear()
with LoggingLevel(logging.WARNING):
with CaptureLogger(logger) as cl: with CaptureLogger(logger) as cl:
config = PretrainedConfig() config = PretrainedConfig()
config.pad_token_id = 0 config.pad_token_id = 0