Add warning message for beta and gamma parameters (#31654)

* Add warning message for  and  parameters

* Fix when the warning is raised

* Formatting changes

* Improve testing and remove duplicated warning from _fix_key
This commit is contained in:
Omar Salman
2024-07-11 17:01:47 +05:00
committed by GitHub
parent 23d6d0cc06
commit 1499a55008
2 changed files with 57 additions and 0 deletions

View File

@@ -1511,6 +1511,57 @@ class ModelUtilsTest(TestCasePlus):
outputs_from_saved = new_model(input_ids)
self.assertTrue(torch.allclose(outputs_from_saved["logits"], outputs["logits"]))
def test_warning_for_beta_gamma_parameters(self):
class TestModelGamma(PreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.gamma_param = nn.Parameter(torch.ones(10))
self.post_init()
def forward(self):
return self.gamma_param.sum()
logger = logging.get_logger("transformers.modeling_utils")
config = PretrainedConfig()
warning_msg_gamma = "A parameter name that contains `gamma` will be renamed internally"
model = TestModelGamma(config)
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir)
with LoggingLevel(logging.WARNING):
with CaptureLogger(logger) as cl1:
_, loading_info = TestModelGamma.from_pretrained(tmp_dir, config=config, output_loading_info=True)
missing_keys = loading_info["missing_keys"]
unexpected_keys = loading_info["unexpected_keys"]
self.assertIn(warning_msg_gamma, cl1.out)
self.assertIn("gamma_param", missing_keys)
self.assertIn("weight_param", unexpected_keys)
class TestModelBeta(PreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.beta_param = nn.Parameter(torch.ones(10))
self.post_init()
def forward(self):
return self.beta_param.sum()
warning_msg_beta = "A parameter name that contains `beta` will be renamed internally"
model = TestModelBeta(config)
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir)
with LoggingLevel(logging.WARNING):
with CaptureLogger(logger) as cl2:
_, loading_info = TestModelBeta.from_pretrained(tmp_dir, config=config, output_loading_info=True)
missing_keys = loading_info["missing_keys"]
unexpected_keys = loading_info["unexpected_keys"]
self.assertIn(warning_msg_beta, cl2.out)
self.assertIn("beta_param", missing_keys)
self.assertIn("bias_param", unexpected_keys)
@slow
@require_torch