From b7345d22d0b59ccfda8df840a918af33cf95a189 Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Mon, 27 Jul 2020 20:00:44 -0400 Subject: [PATCH] [fix] no warning for position_ids buffer (#6063) --- src/transformers/modeling_bert.py | 2 ++ src/transformers/modeling_mobilebert.py | 2 ++ src/transformers/modeling_openai.py | 1 + src/transformers/modeling_xlm.py | 4 +++- 4 files changed, 8 insertions(+), 1 deletion(-) diff --git a/src/transformers/modeling_bert.py b/src/transformers/modeling_bert.py index e27ba7539c..757eb7c9c7 100644 --- a/src/transformers/modeling_bert.py +++ b/src/transformers/modeling_bert.py @@ -699,6 +699,8 @@ class BertModel(BertPreTrainedModel): """ + authorized_missing_keys = [r"position_ids"] + def __init__(self, config): super().__init__(config) self.config = config diff --git a/src/transformers/modeling_mobilebert.py b/src/transformers/modeling_mobilebert.py index b01c29df29..4d78ca0396 100644 --- a/src/transformers/modeling_mobilebert.py +++ b/src/transformers/modeling_mobilebert.py @@ -788,6 +788,8 @@ class MobileBertModel(MobileBertPreTrainedModel): https://arxiv.org/pdf/2004.02984.pdf """ + authorized_missing_keys = [r"position_ids"] + def __init__(self, config): super().__init__(config) self.config = config diff --git a/src/transformers/modeling_openai.py b/src/transformers/modeling_openai.py index e346219c3d..3efa7d353f 100644 --- a/src/transformers/modeling_openai.py +++ b/src/transformers/modeling_openai.py @@ -272,6 +272,7 @@ class OpenAIGPTPreTrainedModel(PreTrainedModel): config_class = OpenAIGPTConfig load_tf_weights = load_tf_weights_in_openai_gpt base_model_prefix = "transformer" + authorized_missing_keys = [r"position_ids"] def _init_weights(self, module): """ Initialize the weights. diff --git a/src/transformers/modeling_xlm.py b/src/transformers/modeling_xlm.py index 932bf807a5..e7396df689 100644 --- a/src/transformers/modeling_xlm.py +++ b/src/transformers/modeling_xlm.py @@ -375,7 +375,9 @@ XLM_INPUTS_DOCSTRING = r""" XLM_START_DOCSTRING, ) class XLMModel(XLMPreTrainedModel): - def __init__(self, config): # , dico, is_encoder, with_output): + authorized_missing_keys = [r"position_ids"] + + def __init__(self, config): super().__init__(config) # encoder / decoder, output layer