Avoid unnecessary warnings when loading pretrained model (#5922)
* Avoid unnecessary warnings when loading pretrained model
* Fix test
* Add other keys to ignore
* keys_to_ignore_at_load -> authorized_missing_keys
This commit is contained in:
@@ -938,6 +938,7 @@ class BartModel(PretrainedBartModel):
|
|||||||
)
|
)
|
||||||
class BartForConditionalGeneration(PretrainedBartModel):
|
class BartForConditionalGeneration(PretrainedBartModel):
|
||||||
base_model_prefix = "model"
|
base_model_prefix = "model"
|
||||||
|
authorized_missing_keys = [r"final_logits_bias", r"encoder\.version", r"decoder\.version"]
|
||||||
|
|
||||||
def __init__(self, config: BartConfig):
|
def __init__(self, config: BartConfig):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|||||||
@@ -577,6 +577,8 @@ class GPT2Model(GPT2PreTrainedModel):
|
|||||||
GPT2_START_DOCSTRING,
|
GPT2_START_DOCSTRING,
|
||||||
)
|
)
|
||||||
class GPT2LMHeadModel(GPT2PreTrainedModel):
|
class GPT2LMHeadModel(GPT2PreTrainedModel):
|
||||||
|
authorized_missing_keys = [r"h\.\d+\.attn\.masked_bias", r"lm_head\.weight"]
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self.transformer = GPT2Model(config)
|
self.transformer = GPT2Model(config)
|
||||||
|
|||||||
@@ -1027,6 +1027,8 @@ class T5Model(T5PreTrainedModel):
|
|||||||
|
|
||||||
@add_start_docstrings("""T5 Model with a `language modeling` head on top. """, T5_START_DOCSTRING)
|
@add_start_docstrings("""T5 Model with a `language modeling` head on top. """, T5_START_DOCSTRING)
|
||||||
class T5ForConditionalGeneration(T5PreTrainedModel):
|
class T5ForConditionalGeneration(T5PreTrainedModel):
|
||||||
|
authorized_missing_keys = [r"encoder\.embed_tokens\.weight", r"decoder\.embed_tokens\.weight", r"lm_head\.weight"]
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self.model_dim = config.d_model
|
self.model_dim = config.d_model
|
||||||
|
|||||||
@@ -17,6 +17,7 @@
|
|||||||
import inspect
|
import inspect
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import re
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Callable, Dict, List, Optional, Tuple
|
from typing import Callable, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
@@ -289,9 +290,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
|||||||
|
|
||||||
- **base_model_prefix** (:obj:`str`) -- A string indicating the attribute associated to the base model in
|
- **base_model_prefix** (:obj:`str`) -- A string indicating the attribute associated to the base model in
|
||||||
derived classes of the same architecture adding modules on top of the base model.
|
derived classes of the same architecture adding modules on top of the base model.
|
||||||
|
- **authorized_missing_keys** (:obj:`Optional[List[str]]`) -- A list of re pattern of tensor names to ignore
|
||||||
|
when loading the model (and avoid unnecessary warnings).
|
||||||
"""
|
"""
|
||||||
config_class = None
|
config_class = None
|
||||||
base_model_prefix = ""
|
base_model_prefix = ""
|
||||||
|
authorized_missing_keys = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def dummy_inputs(self) -> Dict[str, torch.Tensor]:
|
def dummy_inputs(self) -> Dict[str, torch.Tensor]:
|
||||||
@@ -806,9 +810,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
|||||||
head_model_state_dict_without_base_prefix = [
|
head_model_state_dict_without_base_prefix = [
|
||||||
key.split(cls.base_model_prefix + ".")[-1] for key in model.state_dict().keys()
|
key.split(cls.base_model_prefix + ".")[-1] for key in model.state_dict().keys()
|
||||||
]
|
]
|
||||||
|
|
||||||
missing_keys.extend(head_model_state_dict_without_base_prefix - base_model_state_dict)
|
missing_keys.extend(head_model_state_dict_without_base_prefix - base_model_state_dict)
|
||||||
|
|
||||||
|
# Some models may have keys that are not in the state by design, removing them before needlessly warning
|
||||||
|
# the user.
|
||||||
|
if cls.authorized_missing_keys is not None:
|
||||||
|
for pat in cls.authorized_missing_keys:
|
||||||
|
missing_keys = [k for k in missing_keys if re.search(pat, k) is None]
|
||||||
|
|
||||||
if len(unexpected_keys) > 0:
|
if len(unexpected_keys) > 0:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when "
|
f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when "
|
||||||
|
|||||||
@@ -311,6 +311,7 @@ class GPT2ModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
all_generative_model_classes = (
|
all_generative_model_classes = (
|
||||||
(GPT2LMHeadModel,) if is_torch_available() else ()
|
(GPT2LMHeadModel,) if is_torch_available() else ()
|
||||||
) # TODO (PVP): Add Double HeadsModel when generate() function is changed accordingly
|
) # TODO (PVP): Add Double HeadsModel when generate() function is changed accordingly
|
||||||
|
test_missing_keys = False
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.model_tester = GPT2ModelTester(self)
|
self.model_tester = GPT2ModelTester(self)
|
||||||
|
|||||||
Reference in New Issue
Block a user