Avoid unnecessary warnings when loading pretrained model (#5922)

* Avoid unnecessary warnings when loading pretrained model

* Fix test

* Add other keys to ignore

* keys_to_ignore_at_load -> authorized_missing_keys
This commit is contained in:
Sylvain Gugger
2020-07-23 18:13:36 -04:00
committed by GitHub
parent 29afb5764f
commit f5b5c5bd7e
5 changed files with 16 additions and 1 deletions

View File

@@ -938,6 +938,7 @@ class BartModel(PretrainedBartModel):
) )
class BartForConditionalGeneration(PretrainedBartModel): class BartForConditionalGeneration(PretrainedBartModel):
base_model_prefix = "model" base_model_prefix = "model"
authorized_missing_keys = [r"final_logits_bias", r"encoder\.version", r"decoder\.version"]
def __init__(self, config: BartConfig): def __init__(self, config: BartConfig):
super().__init__(config) super().__init__(config)

View File

@@ -577,6 +577,8 @@ class GPT2Model(GPT2PreTrainedModel):
GPT2_START_DOCSTRING, GPT2_START_DOCSTRING,
) )
class GPT2LMHeadModel(GPT2PreTrainedModel): class GPT2LMHeadModel(GPT2PreTrainedModel):
authorized_missing_keys = [r"h\.\d+\.attn\.masked_bias", r"lm_head\.weight"]
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)
self.transformer = GPT2Model(config) self.transformer = GPT2Model(config)

View File

@@ -1027,6 +1027,8 @@ class T5Model(T5PreTrainedModel):
@add_start_docstrings("""T5 Model with a `language modeling` head on top. """, T5_START_DOCSTRING) @add_start_docstrings("""T5 Model with a `language modeling` head on top. """, T5_START_DOCSTRING)
class T5ForConditionalGeneration(T5PreTrainedModel): class T5ForConditionalGeneration(T5PreTrainedModel):
authorized_missing_keys = [r"encoder\.embed_tokens\.weight", r"decoder\.embed_tokens\.weight", r"lm_head\.weight"]
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)
self.model_dim = config.d_model self.model_dim = config.d_model

View File

@@ -17,6 +17,7 @@
import inspect import inspect
import logging import logging
import os import os
import re
from dataclasses import dataclass from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Tuple from typing import Callable, Dict, List, Optional, Tuple
@@ -289,9 +290,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
- **base_model_prefix** (:obj:`str`) -- A string indicating the attribute associated to the base model in - **base_model_prefix** (:obj:`str`) -- A string indicating the attribute associated to the base model in
derived classes of the same architecture adding modules on top of the base model. derived classes of the same architecture adding modules on top of the base model.
- **authorized_missing_keys** (:obj:`Optional[List[str]]`) -- A list of re pattern of tensor names to ignore
when loading the model (and avoid unnecessary warnings).
""" """
config_class = None config_class = None
base_model_prefix = "" base_model_prefix = ""
authorized_missing_keys = None
@property @property
def dummy_inputs(self) -> Dict[str, torch.Tensor]: def dummy_inputs(self) -> Dict[str, torch.Tensor]:
@@ -806,9 +810,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
head_model_state_dict_without_base_prefix = [ head_model_state_dict_without_base_prefix = [
key.split(cls.base_model_prefix + ".")[-1] for key in model.state_dict().keys() key.split(cls.base_model_prefix + ".")[-1] for key in model.state_dict().keys()
] ]
missing_keys.extend(head_model_state_dict_without_base_prefix - base_model_state_dict) missing_keys.extend(head_model_state_dict_without_base_prefix - base_model_state_dict)
# Some models may have keys that are not in the state by design, removing them before needlessly warning
# the user.
if cls.authorized_missing_keys is not None:
for pat in cls.authorized_missing_keys:
missing_keys = [k for k in missing_keys if re.search(pat, k) is None]
if len(unexpected_keys) > 0: if len(unexpected_keys) > 0:
logger.warning( logger.warning(
f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when " f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when "

View File

@@ -311,6 +311,7 @@ class GPT2ModelTest(ModelTesterMixin, unittest.TestCase):
all_generative_model_classes = ( all_generative_model_classes = (
(GPT2LMHeadModel,) if is_torch_available() else () (GPT2LMHeadModel,) if is_torch_available() else ()
) # TODO (PVP): Add Double HeadsModel when generate() function is changed accordingly ) # TODO (PVP): Add Double HeadsModel when generate() function is changed accordingly
test_missing_keys = False
def setUp(self): def setUp(self):
self.model_tester = GPT2ModelTester(self) self.model_tester = GPT2ModelTester(self)