Reduce the error log when using core models that need their weights renamed, and provide a step forward (#32656)

* Fin

* Modify msg

* Finish up nits
This commit is contained in:
Zach Mueller
2024-08-16 13:05:57 -04:00
committed by GitHub
parent 1c36db697a
commit 8ec028aded
3 changed files with 57 additions and 10 deletions

View File

@@ -105,7 +105,6 @@ from .utils.quantization_config import BitsAndBytesConfig, QuantizationMethod
XLA_USE_BF16 = os.environ.get("XLA_USE_BF16", "0").upper() XLA_USE_BF16 = os.environ.get("XLA_USE_BF16", "0").upper()
XLA_DOWNCAST_BF16 = os.environ.get("XLA_DOWNCAST_BF16", "0").upper() XLA_DOWNCAST_BF16 = os.environ.get("XLA_DOWNCAST_BF16", "0").upper()
PARAM_RENAME_WARNING = "A parameter name that contains `{}` will be renamed internally to `{}`. Please use a different name to suppress this warning."
if is_accelerate_available(): if is_accelerate_available():
@@ -693,17 +692,30 @@ def _load_state_dict_into_model(model_to_load, state_dict, start_prefix, assign_
# Convert old format to new format if needed from a PyTorch state_dict # Convert old format to new format if needed from a PyTorch state_dict
old_keys = [] old_keys = []
new_keys = [] new_keys = []
renamed_keys = {}
renamed_gamma = {}
renamed_beta = {}
warning_msg = f"A pretrained model of type `{model_to_load.__class__.__name__}` "
for key in state_dict.keys(): for key in state_dict.keys():
new_key = None new_key = None
if "gamma" in key: if "gamma" in key:
logger.warning(PARAM_RENAME_WARNING.format("gamma", "weight")) # We add only the first key as an example
new_key = key.replace("gamma", "weight") new_key = key.replace("gamma", "weight")
renamed_gamma[key] = new_key if not renamed_gamma else renamed_gamma
if "beta" in key: if "beta" in key:
logger.warning(PARAM_RENAME_WARNING.format("beta", "bias")) # We add only the first key as an example
new_key = key.replace("beta", "bias") new_key = key.replace("beta", "bias")
renamed_beta[key] = new_key if not renamed_beta else renamed_beta
if new_key: if new_key:
old_keys.append(key) old_keys.append(key)
new_keys.append(new_key) new_keys.append(new_key)
renamed_keys = {**renamed_gamma, **renamed_beta}
if renamed_keys:
warning_msg += "contains parameters that have been renamed internally (a few are listed below but more are present in the model):\n"
for old_key, new_key in renamed_keys.items():
warning_msg += f"* `{old_key}` -> `{new_key}`\n"
warning_msg += "If you are using a model from the Hub, consider submitting a PR to adjust these weights and help future users."
logger.info_once(warning_msg)
for old_key, new_key in zip(old_keys, new_keys): for old_key, new_key in zip(old_keys, new_keys):
state_dict[new_key] = state_dict.pop(old_key) state_dict[new_key] = state_dict.pop(old_key)
@@ -819,6 +831,7 @@ def _load_state_dict_into_meta_model(
is_safetensors=False, is_safetensors=False,
keep_in_fp32_modules=None, keep_in_fp32_modules=None,
unexpected_keys=None, # passing `unexpected` for cleanup from quantization items unexpected_keys=None, # passing `unexpected` for cleanup from quantization items
pretrained_model_name_or_path=None, # for flagging the user when the model contains renamed keys
): ):
""" """
This is somewhat similar to `_load_state_dict_into_model`, but deals with a model that has some or all of its This is somewhat similar to `_load_state_dict_into_model`, but deals with a model that has some or all of its
@@ -841,18 +854,30 @@ def _load_state_dict_into_meta_model(
old_keys = [] old_keys = []
new_keys = [] new_keys = []
renamed_gamma = {}
renamed_beta = {}
is_quantized = hf_quantizer is not None is_quantized = hf_quantizer is not None
warning_msg = f"This model {type(model)}"
for key in state_dict.keys(): for key in state_dict.keys():
new_key = None new_key = None
if "gamma" in key: if "gamma" in key:
logger.warning(PARAM_RENAME_WARNING.format("gamma", "weight")) # We add only the first key as an example
new_key = key.replace("gamma", "weight") new_key = key.replace("gamma", "weight")
renamed_gamma[key] = new_key if not renamed_gamma else renamed_gamma
if "beta" in key: if "beta" in key:
logger.warning(PARAM_RENAME_WARNING.format("beta", "bias")) # We add only the first key as an example
new_key = key.replace("beta", "bias") new_key = key.replace("beta", "bias")
renamed_beta[key] = new_key if not renamed_beta else renamed_beta
if new_key: if new_key:
old_keys.append(key) old_keys.append(key)
new_keys.append(new_key) new_keys.append(new_key)
renamed_keys = {**renamed_gamma, **renamed_beta}
if renamed_keys:
warning_msg += "contains parameters that have been renamed internally (a few are listed below but more are present in the model):\n"
for old_key, new_key in renamed_keys.items():
warning_msg += f"* `{old_key}` -> `{new_key}`\n"
warning_msg += "If you are using a model from the Hub, consider submitting a PR to adjust these weights and help future users."
logger.info_once(warning_msg)
for old_key, new_key in zip(old_keys, new_keys): for old_key, new_key in zip(old_keys, new_keys):
state_dict[new_key] = state_dict.pop(old_key) state_dict[new_key] = state_dict.pop(old_key)
@@ -4535,7 +4560,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
@staticmethod @staticmethod
def _load_pretrained_model_low_mem( def _load_pretrained_model_low_mem(
model, loaded_state_dict_keys, resolved_archive_file, start_prefix="", hf_quantizer=None model,
loaded_state_dict_keys,
resolved_archive_file,
start_prefix="",
hf_quantizer=None,
pretrained_model_name_or_path=None,
): ):
""" """
This is an experimental function that loads the model using ~1.x model size CPU memory This is an experimental function that loads the model using ~1.x model size CPU memory

View File

@@ -331,6 +331,21 @@ def warning_once(self, *args, **kwargs):
logging.Logger.warning_once = warning_once logging.Logger.warning_once = warning_once
@functools.lru_cache(None)
def info_once(self, *args, **kwargs):
"""
This method is identical to `logger.info()`, but will emit the info with the same message only once
Note: The cache is for the function arguments, so 2 different callers using the same arguments will hit the cache.
The assumption here is that all warning messages are unique across the code. If they aren't then need to switch to
another type of cache that includes the caller frame information in the hashing function.
"""
self.info(*args, **kwargs)
logging.Logger.info_once = info_once
class EmptyTqdm: class EmptyTqdm:
"""Dummy tqdm which doesn't do anything.""" """Dummy tqdm which doesn't do anything."""

View File

@@ -1640,17 +1640,18 @@ class ModelUtilsTest(TestCasePlus):
logger = logging.get_logger("transformers.modeling_utils") logger = logging.get_logger("transformers.modeling_utils")
config = PretrainedConfig() config = PretrainedConfig()
warning_msg_gamma = "A parameter name that contains `gamma` will be renamed internally" warning_msg_gamma = "`gamma_param` -> `weight_param`"
model = TestModelGamma(config) model = TestModelGamma(config)
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir) model.save_pretrained(tmp_dir)
with LoggingLevel(logging.WARNING): with LoggingLevel(logging.INFO):
with CaptureLogger(logger) as cl1: with CaptureLogger(logger) as cl1:
_, loading_info = TestModelGamma.from_pretrained(tmp_dir, config=config, output_loading_info=True) _, loading_info = TestModelGamma.from_pretrained(tmp_dir, config=config, output_loading_info=True)
missing_keys = loading_info["missing_keys"] missing_keys = loading_info["missing_keys"]
unexpected_keys = loading_info["unexpected_keys"] unexpected_keys = loading_info["unexpected_keys"]
self.assertIn("`TestModelGamma`", cl1.out)
self.assertIn(warning_msg_gamma, cl1.out) self.assertIn(warning_msg_gamma, cl1.out)
self.assertIn("gamma_param", missing_keys) self.assertIn("gamma_param", missing_keys)
self.assertIn("weight_param", unexpected_keys) self.assertIn("weight_param", unexpected_keys)
@@ -1664,17 +1665,18 @@ class ModelUtilsTest(TestCasePlus):
def forward(self): def forward(self):
return self.beta_param.sum() return self.beta_param.sum()
warning_msg_beta = "A parameter name that contains `beta` will be renamed internally" warning_msg_beta = "`beta_param` -> `bias_param`"
model = TestModelBeta(config) model = TestModelBeta(config)
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir) model.save_pretrained(tmp_dir)
with LoggingLevel(logging.WARNING): with LoggingLevel(logging.INFO):
with CaptureLogger(logger) as cl2: with CaptureLogger(logger) as cl2:
_, loading_info = TestModelBeta.from_pretrained(tmp_dir, config=config, output_loading_info=True) _, loading_info = TestModelBeta.from_pretrained(tmp_dir, config=config, output_loading_info=True)
missing_keys = loading_info["missing_keys"] missing_keys = loading_info["missing_keys"]
unexpected_keys = loading_info["unexpected_keys"] unexpected_keys = loading_info["unexpected_keys"]
self.assertIn("`TestModelBeta`", cl2.out)
self.assertIn(warning_msg_beta, cl2.out) self.assertIn(warning_msg_beta, cl2.out)
self.assertIn("beta_param", missing_keys) self.assertIn("beta_param", missing_keys)
self.assertIn("bias_param", unexpected_keys) self.assertIn("bias_param", unexpected_keys)