Add warning message for beta and gamma parameters (#31654)
* Add warning message for `beta` and `gamma` parameters * Fix when the warning is raised * Formatting changes * Improve testing and remove duplicated warning from _fix_key
This commit is contained in:
@@ -104,6 +104,8 @@ from .utils.quantization_config import BitsAndBytesConfig, QuantizationMethod
|
|||||||
|
|
||||||
XLA_USE_BF16 = os.environ.get("XLA_USE_BF16", "0").upper()
|
XLA_USE_BF16 = os.environ.get("XLA_USE_BF16", "0").upper()
|
||||||
XLA_DOWNCAST_BF16 = os.environ.get("XLA_DOWNCAST_BF16", "0").upper()
|
XLA_DOWNCAST_BF16 = os.environ.get("XLA_DOWNCAST_BF16", "0").upper()
|
||||||
|
# Warning template logged while loading a state dict: the first placeholder receives the substring found in
# the parameter name ("gamma" / "beta"), the second the substring it is replaced with ("weight" / "bias").
PARAM_RENAME_WARNING = "A parameter name that contains `{}` will be renamed internally to `{}`. Please use a different name to suppress this warning."
|
||||||
|
|
||||||
|
|
||||||
if is_accelerate_available():
|
if is_accelerate_available():
|
||||||
from accelerate import dispatch_model, infer_auto_device_map, init_empty_weights
|
from accelerate import dispatch_model, infer_auto_device_map, init_empty_weights
|
||||||
@@ -662,8 +664,10 @@ def _load_state_dict_into_model(model_to_load, state_dict, start_prefix):
|
|||||||
for key in state_dict.keys():
|
for key in state_dict.keys():
|
||||||
new_key = None
|
new_key = None
|
||||||
if "gamma" in key:
|
if "gamma" in key:
|
||||||
|
logger.warning(PARAM_RENAME_WARNING.format("gamma", "weight"))
|
||||||
new_key = key.replace("gamma", "weight")
|
new_key = key.replace("gamma", "weight")
|
||||||
if "beta" in key:
|
if "beta" in key:
|
||||||
|
logger.warning(PARAM_RENAME_WARNING.format("beta", "bias"))
|
||||||
new_key = key.replace("beta", "bias")
|
new_key = key.replace("beta", "bias")
|
||||||
if new_key:
|
if new_key:
|
||||||
old_keys.append(key)
|
old_keys.append(key)
|
||||||
@@ -807,8 +811,10 @@ def _load_state_dict_into_meta_model(
|
|||||||
for key in state_dict.keys():
|
for key in state_dict.keys():
|
||||||
new_key = None
|
new_key = None
|
||||||
if "gamma" in key:
|
if "gamma" in key:
|
||||||
|
logger.warning(PARAM_RENAME_WARNING.format("gamma", "weight"))
|
||||||
new_key = key.replace("gamma", "weight")
|
new_key = key.replace("gamma", "weight")
|
||||||
if "beta" in key:
|
if "beta" in key:
|
||||||
|
logger.warning(PARAM_RENAME_WARNING.format("beta", "bias"))
|
||||||
new_key = key.replace("beta", "bias")
|
new_key = key.replace("beta", "bias")
|
||||||
if new_key:
|
if new_key:
|
||||||
old_keys.append(key)
|
old_keys.append(key)
|
||||||
|
|||||||
@@ -1511,6 +1511,57 @@ class ModelUtilsTest(TestCasePlus):
|
|||||||
outputs_from_saved = new_model(input_ids)
|
outputs_from_saved = new_model(input_ids)
|
||||||
self.assertTrue(torch.allclose(outputs_from_saved["logits"], outputs["logits"]))
|
self.assertTrue(torch.allclose(outputs_from_saved["logits"], outputs["logits"]))
|
||||||
|
|
||||||
|
def test_warning_for_beta_gamma_parameters(self):
    """
    Check that reloading a saved model whose parameter name contains `gamma` or `beta` logs the rename
    warning, and that the loading info reports the original name as missing and the renamed one as
    unexpected.
    """

    class TestModelGamma(PreTrainedModel):
        def __init__(self, config):
            super().__init__(config)
            self.gamma_param = nn.Parameter(torch.ones(10))
            self.post_init()

        def forward(self):
            return self.gamma_param.sum()

    class TestModelBeta(PreTrainedModel):
        def __init__(self, config):
            super().__init__(config)
            self.beta_param = nn.Parameter(torch.ones(10))
            self.post_init()

        def forward(self):
            return self.beta_param.sum()

    logger = logging.get_logger("transformers.modeling_utils")
    config = PretrainedConfig()

    # One row per scenario: (model class, expected warning prefix, saved parameter name, name after renaming).
    scenarios = [
        (
            TestModelGamma,
            "A parameter name that contains `gamma` will be renamed internally",
            "gamma_param",
            "weight_param",
        ),
        (
            TestModelBeta,
            "A parameter name that contains `beta` will be renamed internally",
            "beta_param",
            "bias_param",
        ),
    ]

    for model_class, expected_warning, original_name, renamed_name in scenarios:
        model = model_class(config)

        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(tmp_dir)
            # Capture everything the modeling_utils logger emits at WARNING level while reloading.
            with LoggingLevel(logging.WARNING), CaptureLogger(logger) as captured:
                _, loading_info = model_class.from_pretrained(tmp_dir, config=config, output_loading_info=True)

        self.assertIn(expected_warning, captured.out)
        self.assertIn(original_name, loading_info["missing_keys"])
        self.assertIn(renamed_name, loading_info["unexpected_keys"])
|
||||||
|
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
@require_torch
|
@require_torch
|
||||||
|
|||||||
Reference in New Issue
Block a user