Avoid unnecessary warnings when loading pretrained model (#5922)
* Avoid unnecessary warnings when loading pretrained model * Fix test * Add other keys to ignore * keys_to_ignore_at_load -> authorized_missing_keys
This commit is contained in:
@@ -17,6 +17,7 @@
|
||||
import inspect
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Dict, List, Optional, Tuple
|
||||
|
||||
@@ -289,9 +290,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
||||
|
||||
- **base_model_prefix** (:obj:`str`) -- A string indicating the attribute associated to the base model in
|
||||
derived classes of the same architecture adding modules on top of the base model.
|
||||
- **authorized_missing_keys** (:obj:`Optional[List[str]]`) -- A list of re pattern of tensor names to ignore
|
||||
when loading the model (and avoid unnecessary warnings).
|
||||
"""
|
||||
config_class = None
|
||||
base_model_prefix = ""
|
||||
authorized_missing_keys = None
|
||||
|
||||
@property
|
||||
def dummy_inputs(self) -> Dict[str, torch.Tensor]:
|
||||
@@ -806,9 +810,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
||||
head_model_state_dict_without_base_prefix = [
|
||||
key.split(cls.base_model_prefix + ".")[-1] for key in model.state_dict().keys()
|
||||
]
|
||||
|
||||
missing_keys.extend(head_model_state_dict_without_base_prefix - base_model_state_dict)
|
||||
|
||||
# Some models may have keys that are not in the state by design, removing them before needlessly warning
|
||||
# the user.
|
||||
if cls.authorized_missing_keys is not None:
|
||||
for pat in cls.authorized_missing_keys:
|
||||
missing_keys = [k for k in missing_keys if re.search(pat, k) is None]
|
||||
|
||||
if len(unexpected_keys) > 0:
|
||||
logger.warning(
|
||||
f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when "
|
||||
|
||||
Reference in New Issue
Block a user