From f5b5c5bd7e213dea1645f07902b681f88e3cf954 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Thu, 23 Jul 2020 18:13:36 -0400 Subject: [PATCH] Avoid unnecessary warnings when loading pretrained model (#5922) * Avoid unnecessary warnings when loading pretrained model * Fix test * Add other keys to ignore * keys_to_ignore_at_load -> authorized_missing_keys --- src/transformers/modeling_bart.py | 1 + src/transformers/modeling_gpt2.py | 2 ++ src/transformers/modeling_t5.py | 2 ++ src/transformers/modeling_utils.py | 11 ++++++++++- tests/test_modeling_gpt2.py | 1 + 5 files changed, 16 insertions(+), 1 deletion(-) diff --git a/src/transformers/modeling_bart.py b/src/transformers/modeling_bart.py index 66a6527f43..1104567a48 100644 --- a/src/transformers/modeling_bart.py +++ b/src/transformers/modeling_bart.py @@ -938,6 +938,7 @@ class BartModel(PretrainedBartModel): ) class BartForConditionalGeneration(PretrainedBartModel): base_model_prefix = "model" + authorized_missing_keys = [r"final_logits_bias", r"encoder\.version", r"decoder\.version"] def __init__(self, config: BartConfig): super().__init__(config) diff --git a/src/transformers/modeling_gpt2.py b/src/transformers/modeling_gpt2.py index 010513bc99..0514586a5f 100644 --- a/src/transformers/modeling_gpt2.py +++ b/src/transformers/modeling_gpt2.py @@ -577,6 +577,8 @@ class GPT2Model(GPT2PreTrainedModel): GPT2_START_DOCSTRING, ) class GPT2LMHeadModel(GPT2PreTrainedModel): + authorized_missing_keys = [r"h\.\d+\.attn\.masked_bias", r"lm_head\.weight"] + def __init__(self, config): super().__init__(config) self.transformer = GPT2Model(config) diff --git a/src/transformers/modeling_t5.py b/src/transformers/modeling_t5.py index 612cadf2c0..925ab53e36 100644 --- a/src/transformers/modeling_t5.py +++ b/src/transformers/modeling_t5.py @@ -1027,6 +1027,8 @@ class T5Model(T5PreTrainedModel): @add_start_docstrings("""T5 Model with a `language modeling` head on top. 
""", T5_START_DOCSTRING) class T5ForConditionalGeneration(T5PreTrainedModel): + authorized_missing_keys = [r"encoder\.embed_tokens\.weight", r"decoder\.embed_tokens\.weight", r"lm_head\.weight"] + def __init__(self, config): super().__init__(config) self.model_dim = config.d_model diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index c4013c2b72..1850589cf6 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -17,6 +17,7 @@ import inspect import logging import os +import re from dataclasses import dataclass from typing import Callable, Dict, List, Optional, Tuple @@ -289,9 +290,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): - **base_model_prefix** (:obj:`str`) -- A string indicating the attribute associated to the base model in derived classes of the same architecture adding modules on top of the base model. + - **authorized_missing_keys** (:obj:`Optional[List[str]]`) -- A list of regex patterns of tensor names to ignore + when loading the model (and avoid unnecessary warnings). """ config_class = None base_model_prefix = "" + authorized_missing_keys = None @property def dummy_inputs(self) -> Dict[str, torch.Tensor]: @@ -806,9 +810,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): head_model_state_dict_without_base_prefix = [ key.split(cls.base_model_prefix + ".")[-1] for key in model.state_dict().keys() ] - missing_keys.extend(head_model_state_dict_without_base_prefix - base_model_state_dict) + # Some models may have keys that are not in the state by design; remove them to avoid needlessly warning + # the user. 
+ if cls.authorized_missing_keys is not None: + for pat in cls.authorized_missing_keys: + missing_keys = [k for k in missing_keys if re.search(pat, k) is None] + if len(unexpected_keys) > 0: logger.warning( f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when " diff --git a/tests/test_modeling_gpt2.py b/tests/test_modeling_gpt2.py index 71f8759110..b97d9d3856 100644 --- a/tests/test_modeling_gpt2.py +++ b/tests/test_modeling_gpt2.py @@ -311,6 +311,7 @@ class GPT2ModelTest(ModelTesterMixin, unittest.TestCase): all_generative_model_classes = ( (GPT2LMHeadModel,) if is_torch_available() else () ) # TODO (PVP): Add Double HeadsModel when generate() function is changed accordingly + test_missing_keys = False def setUp(self): self.model_tester = GPT2ModelTester(self)