Reduce the error log when using core models that need their weights renamed, and provide a step forward (#32656)
* Fin * Modify msg * Finish up nits
This commit is contained in:
@@ -105,7 +105,6 @@ from .utils.quantization_config import BitsAndBytesConfig, QuantizationMethod
|
||||
|
||||
XLA_USE_BF16 = os.environ.get("XLA_USE_BF16", "0").upper()
|
||||
XLA_DOWNCAST_BF16 = os.environ.get("XLA_DOWNCAST_BF16", "0").upper()
|
||||
PARAM_RENAME_WARNING = "A parameter name that contains `{}` will be renamed internally to `{}`. Please use a different name to suppress this warning."
|
||||
|
||||
|
||||
if is_accelerate_available():
|
||||
@@ -693,17 +692,30 @@ def _load_state_dict_into_model(model_to_load, state_dict, start_prefix, assign_
|
||||
# Convert old format to new format if needed from a PyTorch state_dict
|
||||
old_keys = []
|
||||
new_keys = []
|
||||
renamed_keys = {}
|
||||
renamed_gamma = {}
|
||||
renamed_beta = {}
|
||||
warning_msg = f"A pretrained model of type `{model_to_load.__class__.__name__}` "
|
||||
for key in state_dict.keys():
|
||||
new_key = None
|
||||
if "gamma" in key:
|
||||
logger.warning(PARAM_RENAME_WARNING.format("gamma", "weight"))
|
||||
# We add only the first key as an example
|
||||
new_key = key.replace("gamma", "weight")
|
||||
renamed_gamma[key] = new_key if not renamed_gamma else renamed_gamma
|
||||
if "beta" in key:
|
||||
logger.warning(PARAM_RENAME_WARNING.format("beta", "bias"))
|
||||
# We add only the first key as an example
|
||||
new_key = key.replace("beta", "bias")
|
||||
renamed_beta[key] = new_key if not renamed_beta else renamed_beta
|
||||
if new_key:
|
||||
old_keys.append(key)
|
||||
new_keys.append(new_key)
|
||||
renamed_keys = {**renamed_gamma, **renamed_beta}
|
||||
if renamed_keys:
|
||||
warning_msg += "contains parameters that have been renamed internally (a few are listed below but more are present in the model):\n"
|
||||
for old_key, new_key in renamed_keys.items():
|
||||
warning_msg += f"* `{old_key}` -> `{new_key}`\n"
|
||||
warning_msg += "If you are using a model from the Hub, consider submitting a PR to adjust these weights and help future users."
|
||||
logger.info_once(warning_msg)
|
||||
for old_key, new_key in zip(old_keys, new_keys):
|
||||
state_dict[new_key] = state_dict.pop(old_key)
|
||||
|
||||
@@ -819,6 +831,7 @@ def _load_state_dict_into_meta_model(
|
||||
is_safetensors=False,
|
||||
keep_in_fp32_modules=None,
|
||||
unexpected_keys=None, # passing `unexpected` for cleanup from quantization items
|
||||
pretrained_model_name_or_path=None, # for flagging the user when the model contains renamed keys
|
||||
):
|
||||
"""
|
||||
This is somewhat similar to `_load_state_dict_into_model`, but deals with a model that has some or all of its
|
||||
@@ -841,18 +854,30 @@ def _load_state_dict_into_meta_model(
|
||||
|
||||
old_keys = []
|
||||
new_keys = []
|
||||
renamed_gamma = {}
|
||||
renamed_beta = {}
|
||||
is_quantized = hf_quantizer is not None
|
||||
warning_msg = f"This model {type(model)}"
|
||||
for key in state_dict.keys():
|
||||
new_key = None
|
||||
if "gamma" in key:
|
||||
logger.warning(PARAM_RENAME_WARNING.format("gamma", "weight"))
|
||||
# We add only the first key as an example
|
||||
new_key = key.replace("gamma", "weight")
|
||||
renamed_gamma[key] = new_key if not renamed_gamma else renamed_gamma
|
||||
if "beta" in key:
|
||||
logger.warning(PARAM_RENAME_WARNING.format("beta", "bias"))
|
||||
# We add only the first key as an example
|
||||
new_key = key.replace("beta", "bias")
|
||||
renamed_beta[key] = new_key if not renamed_beta else renamed_beta
|
||||
if new_key:
|
||||
old_keys.append(key)
|
||||
new_keys.append(new_key)
|
||||
renamed_keys = {**renamed_gamma, **renamed_beta}
|
||||
if renamed_keys:
|
||||
warning_msg += "contains parameters that have been renamed internally (a few are listed below but more are present in the model):\n"
|
||||
for old_key, new_key in renamed_keys.items():
|
||||
warning_msg += f"* `{old_key}` -> `{new_key}`\n"
|
||||
warning_msg += "If you are using a model from the Hub, consider submitting a PR to adjust these weights and help future users."
|
||||
logger.info_once(warning_msg)
|
||||
for old_key, new_key in zip(old_keys, new_keys):
|
||||
state_dict[new_key] = state_dict.pop(old_key)
|
||||
|
||||
@@ -4535,7 +4560,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
|
||||
@staticmethod
|
||||
def _load_pretrained_model_low_mem(
|
||||
model, loaded_state_dict_keys, resolved_archive_file, start_prefix="", hf_quantizer=None
|
||||
model,
|
||||
loaded_state_dict_keys,
|
||||
resolved_archive_file,
|
||||
start_prefix="",
|
||||
hf_quantizer=None,
|
||||
pretrained_model_name_or_path=None,
|
||||
):
|
||||
"""
|
||||
This is an experimental function that loads the model using ~1.x model size CPU memory
|
||||
|
||||
@@ -331,6 +331,21 @@ def warning_once(self, *args, **kwargs):
|
||||
logging.Logger.warning_once = warning_once
|
||||
|
||||
|
||||
@functools.lru_cache(None)
|
||||
def info_once(self, *args, **kwargs):
|
||||
"""
|
||||
This method is identical to `logger.info()`, but will emit the info with the same message only once
|
||||
|
||||
Note: The cache is for the function arguments, so 2 different callers using the same arguments will hit the cache.
|
||||
The assumption here is that all info messages are unique across the code. If they aren't, then we need to switch to
|
||||
another type of cache that includes the caller frame information in the hashing function.
|
||||
"""
|
||||
self.info(*args, **kwargs)
|
||||
|
||||
|
||||
logging.Logger.info_once = info_once
|
||||
|
||||
|
||||
class EmptyTqdm:
|
||||
"""Dummy tqdm which doesn't do anything."""
|
||||
|
||||
|
||||
@@ -1640,17 +1640,18 @@ class ModelUtilsTest(TestCasePlus):
|
||||
|
||||
logger = logging.get_logger("transformers.modeling_utils")
|
||||
config = PretrainedConfig()
|
||||
warning_msg_gamma = "A parameter name that contains `gamma` will be renamed internally"
|
||||
warning_msg_gamma = "`gamma_param` -> `weight_param`"
|
||||
model = TestModelGamma(config)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.save_pretrained(tmp_dir)
|
||||
with LoggingLevel(logging.WARNING):
|
||||
with LoggingLevel(logging.INFO):
|
||||
with CaptureLogger(logger) as cl1:
|
||||
_, loading_info = TestModelGamma.from_pretrained(tmp_dir, config=config, output_loading_info=True)
|
||||
|
||||
missing_keys = loading_info["missing_keys"]
|
||||
unexpected_keys = loading_info["unexpected_keys"]
|
||||
self.assertIn("`TestModelGamma`", cl1.out)
|
||||
self.assertIn(warning_msg_gamma, cl1.out)
|
||||
self.assertIn("gamma_param", missing_keys)
|
||||
self.assertIn("weight_param", unexpected_keys)
|
||||
@@ -1664,17 +1665,18 @@ class ModelUtilsTest(TestCasePlus):
|
||||
def forward(self):
|
||||
return self.beta_param.sum()
|
||||
|
||||
warning_msg_beta = "A parameter name that contains `beta` will be renamed internally"
|
||||
warning_msg_beta = "`beta_param` -> `bias_param`"
|
||||
model = TestModelBeta(config)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.save_pretrained(tmp_dir)
|
||||
with LoggingLevel(logging.WARNING):
|
||||
with LoggingLevel(logging.INFO):
|
||||
with CaptureLogger(logger) as cl2:
|
||||
_, loading_info = TestModelBeta.from_pretrained(tmp_dir, config=config, output_loading_info=True)
|
||||
|
||||
missing_keys = loading_info["missing_keys"]
|
||||
unexpected_keys = loading_info["unexpected_keys"]
|
||||
self.assertIn("`TestModelBeta`", cl2.out)
|
||||
self.assertIn(warning_msg_beta, cl2.out)
|
||||
self.assertIn("beta_param", missing_keys)
|
||||
self.assertIn("bias_param", unexpected_keys)
|
||||
|
||||
Reference in New Issue
Block a user