🚨🚨🚨 An attempt to fix #29554. Include 'LayerNorm.' in gamma/beta rename scope, optimize string search. (#35615)

* An attempt to fix #29554. Include 'LayerNorm.' in gamma/beta rename scope, reduce number of characters searched on every load considerably.

* Fix fix on load issue

* Fix gamma/beta warning test

* A style complaint

* Improve efficiency of weight norm key rename. Add better comments about weight norm and layer norm renaming.

* Habitual elif redunant with the return
This commit is contained in:
Ross Wightman
2025-01-16 17:25:44 -08:00
committed by GitHub
parent 02a492a838
commit 8c1b5d3782
3 changed files with 60 additions and 65 deletions

View File

@@ -1618,57 +1618,47 @@ class ModelUtilsTest(TestCasePlus):
self.assertTrue(torch.allclose(outputs_from_saved["logits"], outputs["logits"]))
def test_warning_for_beta_gamma_parameters(self):
class TestModelGamma(PreTrainedModel):
class TestGammaBetaNorm(torch.nn.Module):
def __init__(self):
super().__init__()
self.gamma = torch.nn.Parameter(torch.ones(1))
self.beta = torch.nn.Parameter(torch.zeros(1))
def forward(self):
return self.gamma.sum() + self.beta.sum()
class TestModelGammaBeta(PreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.gamma_param = nn.Parameter(torch.ones(10))
self.LayerNorm = TestGammaBetaNorm()
self.post_init()
def forward(self):
return self.gamma_param.sum()
return self.LayerNorm()
logger = logging.get_logger("transformers.modeling_utils")
config = PretrainedConfig()
warning_msg_gamma = "`gamma_param` -> `weight_param`"
model = TestModelGamma(config)
warning_msg_gamma = "`LayerNorm.gamma` -> `LayerNorm.weight`"
warning_msg_beta = "`LayerNorm.beta` -> `LayerNorm.bias`"
model = TestModelGammaBeta(config)
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir)
with LoggingLevel(logging.INFO):
with CaptureLogger(logger) as cl1:
_, loading_info = TestModelGamma.from_pretrained(tmp_dir, config=config, output_loading_info=True)
_, loading_info = TestModelGammaBeta.from_pretrained(
tmp_dir, config=config, output_loading_info=True
)
missing_keys = loading_info["missing_keys"]
unexpected_keys = loading_info["unexpected_keys"]
self.assertIn("`TestModelGamma`", cl1.out)
self.assertIn("`TestModelGammaBeta`", cl1.out)
self.assertIn(warning_msg_gamma, cl1.out)
self.assertIn("gamma_param", missing_keys)
self.assertIn("weight_param", unexpected_keys)
class TestModelBeta(PreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.beta_param = nn.Parameter(torch.ones(10))
self.post_init()
def forward(self):
return self.beta_param.sum()
warning_msg_beta = "`beta_param` -> `bias_param`"
model = TestModelBeta(config)
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir)
with LoggingLevel(logging.INFO):
with CaptureLogger(logger) as cl2:
_, loading_info = TestModelBeta.from_pretrained(tmp_dir, config=config, output_loading_info=True)
missing_keys = loading_info["missing_keys"]
unexpected_keys = loading_info["unexpected_keys"]
self.assertIn("`TestModelBeta`", cl2.out)
self.assertIn(warning_msg_beta, cl2.out)
self.assertIn("beta_param", missing_keys)
self.assertIn("bias_param", unexpected_keys)
self.assertIn(warning_msg_beta, cl1.out)
self.assertIn("LayerNorm.gamma", missing_keys)
self.assertIn("LayerNorm.weight", unexpected_keys)
self.assertIn("LayerNorm.beta", missing_keys)
self.assertIn("LayerNorm.bias", unexpected_keys)
def test_isin_mps_friendly(self):
"""tests that our custom `isin_mps_friendly` matches `torch.isin`"""