consistent ignore keys + make private (#8737)

* consistent ignore keys + make private

* style

* - authorized_missing_keys    => _keys_to_ignore_on_load_missing
  - authorized_unexpected_keys => _keys_to_ignore_on_load_unexpected
  - keys_to_never_save         => _keys_to_ignore_on_save

* move public doc of private attributes to private comment
This commit is contained in:
Stas Bekman
2020-11-23 12:33:13 -08:00
committed by GitHub
parent 49759c0cda
commit e84786aaa6
38 changed files with 127 additions and 126 deletions

View File

@@ -164,9 +164,9 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
if allow_missing_keys: if allow_missing_keys:
missing_keys.append(name) missing_keys.append(name)
continue continue
elif tf_model.authorized_missing_keys is not None: elif tf_model._keys_to_ignore_on_load_missing is not None:
# authorized missing keys don't have to be loaded # authorized missing keys don't have to be loaded
if any(re.search(pat, name) is not None for pat in tf_model.authorized_missing_keys): if any(re.search(pat, name) is not None for pat in tf_model._keys_to_ignore_on_load_missing):
continue continue
raise AttributeError("{} not found in PyTorch model".format(name)) raise AttributeError("{} not found in PyTorch model".format(name))
@@ -209,11 +209,11 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
unexpected_keys = list(all_pytorch_weights) unexpected_keys = list(all_pytorch_weights)
if tf_model.authorized_missing_keys is not None: if tf_model._keys_to_ignore_on_load_missing is not None:
for pat in tf_model.authorized_missing_keys: for pat in tf_model._keys_to_ignore_on_load_missing:
missing_keys = [k for k in missing_keys if re.search(pat, k) is None] missing_keys = [k for k in missing_keys if re.search(pat, k) is None]
if tf_model.authorized_unexpected_keys is not None: if tf_model._keys_to_ignore_on_load_unexpected is not None:
for pat in tf_model.authorized_unexpected_keys: for pat in tf_model._keys_to_ignore_on_load_unexpected:
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None] unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
if len(unexpected_keys) > 0: if len(unexpected_keys) > 0:

View File

@@ -343,15 +343,15 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
:class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture. :class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture.
- **base_model_prefix** (:obj:`str`) -- A string indicating the attribute associated to the base model in - **base_model_prefix** (:obj:`str`) -- A string indicating the attribute associated to the base model in
derived classes of the same architecture adding modules on top of the base model. derived classes of the same architecture adding modules on top of the base model.
- **authorized_missing_keys** (:obj:`List[str]`, `optional`) -- A list of re pattern of tensor names to ignore
from the model when loading the model weights (and avoid unnecessary warnings).
- **authorized_unexpected_keys** (:obj:`List[str]`, `optional`) -- A list of re pattern of tensor names to
ignore from the weights when loading the model weights (and avoid unnecessary warnings).
""" """
config_class = None config_class = None
base_model_prefix = "" base_model_prefix = ""
authorized_missing_keys = None # a list of re pattern of tensor names to ignore from the model when loading the model weights
authorized_unexpected_keys = None # (and avoid unnecessary warnings).
_keys_to_ignore_on_load_missing = None
# a list of re pattern of tensor names to ignore from the weights when loading the model weights
# (and avoid unnecessary warnings).
_keys_to_ignore_on_load_unexpected = None
@property @property
def dummy_inputs(self) -> Dict[str, tf.Tensor]: def dummy_inputs(self) -> Dict[str, tf.Tensor]:
@@ -742,12 +742,12 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
model(model.dummy_inputs, training=False) # Make sure restore ops are run model(model.dummy_inputs, training=False) # Make sure restore ops are run
if cls.authorized_missing_keys is not None: if cls._keys_to_ignore_on_load_missing is not None:
for pat in cls.authorized_missing_keys: for pat in cls._keys_to_ignore_on_load_missing:
missing_keys = [k for k in missing_keys if re.search(pat, k) is None] missing_keys = [k for k in missing_keys if re.search(pat, k) is None]
if cls.authorized_unexpected_keys is not None: if cls._keys_to_ignore_on_load_unexpected is not None:
for pat in cls.authorized_unexpected_keys: for pat in cls._keys_to_ignore_on_load_unexpected:
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None] unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
if len(unexpected_keys) > 0: if len(unexpected_keys) > 0:

View File

@@ -404,17 +404,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
- **base_model_prefix** (:obj:`str`) -- A string indicating the attribute associated to the base model in - **base_model_prefix** (:obj:`str`) -- A string indicating the attribute associated to the base model in
derived classes of the same architecture adding modules on top of the base model. derived classes of the same architecture adding modules on top of the base model.
- **authorized_missing_keys** (:obj:`Optional[List[str]]`) -- A list of re pattern of tensor names to ignore
when loading the model (and avoid unnecessary warnings).
- **keys_to_never_save** (:obj:`Optional[List[str]]`) -- A list of of tensor names to ignore when saving the
model (useful for keys that aren't trained, but which are deterministic)
""" """
config_class = None config_class = None
base_model_prefix = "" base_model_prefix = ""
authorized_missing_keys = None # a list of re pattern of tensor names to ignore from the model when loading the model weights
authorized_unexpected_keys = None # (and avoid unnecessary warnings).
keys_to_never_save = None _keys_to_ignore_on_load_missing = None
# a list of re pattern of tensor names to ignore from the weights when loading the model weights
# (and avoid unnecessary warnings).
_keys_to_ignore_on_load_unexpected = None
# a list of tensor names to ignore when saving the model (useful for keys that aren't
# trained, but which are deterministic)
_keys_to_ignore_on_save = None
@property @property
def dummy_inputs(self) -> Dict[str, torch.Tensor]: def dummy_inputs(self) -> Dict[str, torch.Tensor]:
@@ -719,8 +720,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
state_dict = model_to_save.state_dict() state_dict = model_to_save.state_dict()
# Handle the case where some state_dict keys shouldn't be saved # Handle the case where some state_dict keys shouldn't be saved
if self.keys_to_never_save is not None: if self._keys_to_ignore_on_save is not None:
state_dict = {k: v for k, v in state_dict.items() if k not in self.keys_to_never_save} state_dict = {k: v for k, v in state_dict.items() if k not in self._keys_to_ignore_on_save}
# If we save using the predefined names, we can load using `from_pretrained` # If we save using the predefined names, we can load using `from_pretrained`
output_model_file = os.path.join(save_directory, WEIGHTS_NAME) output_model_file = os.path.join(save_directory, WEIGHTS_NAME)
@@ -1034,12 +1035,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
# Some models may have keys that are not in the state by design, removing them before needlessly warning # Some models may have keys that are not in the state by design, removing them before needlessly warning
# the user. # the user.
if cls.authorized_missing_keys is not None: if cls._keys_to_ignore_on_load_missing is not None:
for pat in cls.authorized_missing_keys: for pat in cls._keys_to_ignore_on_load_missing:
missing_keys = [k for k in missing_keys if re.search(pat, k) is None] missing_keys = [k for k in missing_keys if re.search(pat, k) is None]
if cls.authorized_unexpected_keys is not None: if cls._keys_to_ignore_on_load_unexpected is not None:
for pat in cls.authorized_unexpected_keys: for pat in cls._keys_to_ignore_on_load_unexpected:
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None] unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
if len(unexpected_keys) > 0: if len(unexpected_keys) > 0:

View File

@@ -459,7 +459,7 @@ class AlbertPreTrainedModel(PreTrainedModel):
config_class = AlbertConfig config_class = AlbertConfig
base_model_prefix = "albert" base_model_prefix = "albert"
authorized_missing_keys = [r"position_ids"] _keys_to_ignore_on_load_missing = [r"position_ids"]
def _init_weights(self, module): def _init_weights(self, module):
"""Initialize the weights.""" """Initialize the weights."""
@@ -851,7 +851,7 @@ class AlbertSOPHead(nn.Module):
) )
class AlbertForMaskedLM(AlbertPreTrainedModel): class AlbertForMaskedLM(AlbertPreTrainedModel):
authorized_unexpected_keys = [r"pooler"] _keys_to_ignore_on_load_unexpected = [r"pooler"]
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)
@@ -1021,7 +1021,7 @@ class AlbertForSequenceClassification(AlbertPreTrainedModel):
) )
class AlbertForTokenClassification(AlbertPreTrainedModel): class AlbertForTokenClassification(AlbertPreTrainedModel):
authorized_unexpected_keys = [r"pooler"] _keys_to_ignore_on_load_unexpected = [r"pooler"]
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)
@@ -1110,7 +1110,7 @@ class AlbertForTokenClassification(AlbertPreTrainedModel):
) )
class AlbertForQuestionAnswering(AlbertPreTrainedModel): class AlbertForQuestionAnswering(AlbertPreTrainedModel):
authorized_unexpected_keys = [r"pooler"] _keys_to_ignore_on_load_unexpected = [r"pooler"]
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)

View File

@@ -843,7 +843,7 @@ class TFAlbertSOPHead(tf.keras.layers.Layer):
@add_start_docstrings("""Albert Model with a `language modeling` head on top. """, ALBERT_START_DOCSTRING) @add_start_docstrings("""Albert Model with a `language modeling` head on top. """, ALBERT_START_DOCSTRING)
class TFAlbertForMaskedLM(TFAlbertPreTrainedModel, TFMaskedLanguageModelingLoss): class TFAlbertForMaskedLM(TFAlbertPreTrainedModel, TFMaskedLanguageModelingLoss):
authorized_missing_keys = [r"pooler"] _keys_to_ignore_on_load_missing = [r"pooler"]
def __init__(self, config, *inputs, **kwargs): def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs) super().__init__(config, *inputs, **kwargs)
@@ -1013,7 +1013,7 @@ class TFAlbertForSequenceClassification(TFAlbertPreTrainedModel, TFSequenceClass
) )
class TFAlbertForTokenClassification(TFAlbertPreTrainedModel, TFTokenClassificationLoss): class TFAlbertForTokenClassification(TFAlbertPreTrainedModel, TFTokenClassificationLoss):
authorized_missing_keys = [r"pooler"] _keys_to_ignore_on_load_missing = [r"pooler"]
def __init__(self, config, *inputs, **kwargs): def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs) super().__init__(config, *inputs, **kwargs)
@@ -1100,7 +1100,7 @@ class TFAlbertForTokenClassification(TFAlbertPreTrainedModel, TFTokenClassificat
) )
class TFAlbertForQuestionAnswering(TFAlbertPreTrainedModel, TFQuestionAnsweringLoss): class TFAlbertForQuestionAnswering(TFAlbertPreTrainedModel, TFQuestionAnsweringLoss):
authorized_missing_keys = [r"pooler"] _keys_to_ignore_on_load_missing = [r"pooler"]
def __init__(self, config, *inputs, **kwargs): def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs) super().__init__(config, *inputs, **kwargs)

View File

@@ -946,7 +946,7 @@ class BartModel(PretrainedBartModel):
) )
class BartForConditionalGeneration(PretrainedBartModel): class BartForConditionalGeneration(PretrainedBartModel):
base_model_prefix = "model" base_model_prefix = "model"
authorized_missing_keys = [r"final_logits_bias", r"encoder\.version", r"decoder\.version"] _keys_to_ignore_on_load_missing = [r"final_logits_bias", r"encoder\.version", r"decoder\.version"]
def __init__(self, config: BartConfig): def __init__(self, config: BartConfig):
super().__init__(config) super().__init__(config)

View File

@@ -1020,10 +1020,10 @@ class TFBartModel(TFPretrainedBartModel):
) )
class TFBartForConditionalGeneration(TFPretrainedBartModel): class TFBartForConditionalGeneration(TFPretrainedBartModel):
base_model_prefix = "model" base_model_prefix = "model"
authorized_missing_keys = [ _keys_to_ignore_on_load_missing = [
r"final_logits_bias", r"final_logits_bias",
] ]
authorized_unexpected_keys = [ _keys_to_ignore_on_load_unexpected = [
r"model.encoder.embed_tokens.weight", r"model.encoder.embed_tokens.weight",
r"model.decoder.embed_tokens.weight", r"model.decoder.embed_tokens.weight",
] ]

View File

@@ -598,7 +598,7 @@ class BertPreTrainedModel(PreTrainedModel):
config_class = BertConfig config_class = BertConfig
load_tf_weights = load_tf_weights_in_bert load_tf_weights = load_tf_weights_in_bert
base_model_prefix = "bert" base_model_prefix = "bert"
authorized_missing_keys = [r"position_ids"] _keys_to_ignore_on_load_missing = [r"position_ids"]
def _init_weights(self, module): def _init_weights(self, module):
""" Initialize the weights """ """ Initialize the weights """
@@ -969,8 +969,8 @@ class BertForPreTraining(BertPreTrainedModel):
) )
class BertLMHeadModel(BertPreTrainedModel): class BertLMHeadModel(BertPreTrainedModel):
authorized_unexpected_keys = [r"pooler"] _keys_to_ignore_on_load_unexpected = [r"pooler"]
authorized_missing_keys = [r"position_ids", r"predictions.decoder.bias"] _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)
@@ -1087,8 +1087,8 @@ class BertLMHeadModel(BertPreTrainedModel):
@add_start_docstrings("""Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING) @add_start_docstrings("""Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING)
class BertForMaskedLM(BertPreTrainedModel): class BertForMaskedLM(BertPreTrainedModel):
authorized_unexpected_keys = [r"pooler"] _keys_to_ignore_on_load_unexpected = [r"pooler"]
authorized_missing_keys = [r"position_ids", r"predictions.decoder.bias"] _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)
@@ -1469,7 +1469,7 @@ class BertForMultipleChoice(BertPreTrainedModel):
) )
class BertForTokenClassification(BertPreTrainedModel): class BertForTokenClassification(BertPreTrainedModel):
authorized_unexpected_keys = [r"pooler"] _keys_to_ignore_on_load_unexpected = [r"pooler"]
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)
@@ -1560,7 +1560,7 @@ class BertForTokenClassification(BertPreTrainedModel):
) )
class BertForQuestionAnswering(BertPreTrainedModel): class BertForQuestionAnswering(BertPreTrainedModel):
authorized_unexpected_keys = [r"pooler"] _keys_to_ignore_on_load_unexpected = [r"pooler"]
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)

View File

@@ -938,8 +938,8 @@ class TFBertForPreTraining(TFBertPreTrainedModel, TFBertPreTrainingLoss):
@add_start_docstrings("""Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING) @add_start_docstrings("""Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING)
class TFBertForMaskedLM(TFBertPreTrainedModel, TFMaskedLanguageModelingLoss): class TFBertForMaskedLM(TFBertPreTrainedModel, TFMaskedLanguageModelingLoss):
authorized_unexpected_keys = [r"pooler"] _keys_to_ignore_on_load_unexpected = [r"pooler"]
authorized_missing_keys = [r"pooler"] _keys_to_ignore_on_load_missing = [r"pooler"]
def __init__(self, config, *inputs, **kwargs): def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs) super().__init__(config, *inputs, **kwargs)
@@ -1023,8 +1023,8 @@ class TFBertForMaskedLM(TFBertPreTrainedModel, TFMaskedLanguageModelingLoss):
class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss): class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss):
authorized_unexpected_keys = [r"pooler"] _keys_to_ignore_on_load_unexpected = [r"pooler"]
authorized_missing_keys = [r"pooler"] _keys_to_ignore_on_load_missing = [r"pooler"]
def __init__(self, config, *inputs, **kwargs): def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs) super().__init__(config, *inputs, **kwargs)
@@ -1416,8 +1416,8 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss):
) )
class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationLoss): class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationLoss):
authorized_unexpected_keys = [r"pooler"] _keys_to_ignore_on_load_unexpected = [r"pooler"]
authorized_missing_keys = [r"pooler"] _keys_to_ignore_on_load_missing = [r"pooler"]
def __init__(self, config, *inputs, **kwargs): def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs) super().__init__(config, *inputs, **kwargs)
@@ -1502,8 +1502,8 @@ class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationL
) )
class TFBertForQuestionAnswering(TFBertPreTrainedModel, TFQuestionAnsweringLoss): class TFBertForQuestionAnswering(TFBertPreTrainedModel, TFQuestionAnsweringLoss):
authorized_unexpected_keys = [r"pooler"] _keys_to_ignore_on_load_unexpected = [r"pooler"]
authorized_missing_keys = [r"pooler"] _keys_to_ignore_on_load_missing = [r"pooler"]
def __init__(self, config, *inputs, **kwargs): def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs) super().__init__(config, *inputs, **kwargs)

View File

@@ -173,7 +173,7 @@ class BertGenerationPreTrainedModel(PreTrainedModel):
config_class = BertGenerationConfig config_class = BertGenerationConfig
base_model_prefix = "bert" base_model_prefix = "bert"
authorized_missing_keys = [r"position_ids"] _keys_to_ignore_on_load_missing = [r"position_ids"]
def _init_weights(self, module): def _init_weights(self, module):
""" Initialize the weights """ """ Initialize the weights """

View File

@@ -756,7 +756,7 @@ class DebertaPreTrainedModel(PreTrainedModel):
config_class = DebertaConfig config_class = DebertaConfig
base_model_prefix = "deberta" base_model_prefix = "deberta"
authorized_missing_keys = ["position_ids"] _keys_to_ignore_on_load_missing = ["position_ids"]
def _init_weights(self, module): def _init_weights(self, module):
""" Initialize the weights """ """ Initialize the weights """

View File

@@ -279,7 +279,7 @@ class DPRPretrainedContextEncoder(PreTrainedModel):
config_class = DPRConfig config_class = DPRConfig
load_tf_weights = None load_tf_weights = None
base_model_prefix = "ctx_encoder" base_model_prefix = "ctx_encoder"
authorized_missing_keys = [r"position_ids"] _keys_to_ignore_on_load_missing = [r"position_ids"]
def init_weights(self): def init_weights(self):
self.ctx_encoder.init_weights() self.ctx_encoder.init_weights()
@@ -294,7 +294,7 @@ class DPRPretrainedQuestionEncoder(PreTrainedModel):
config_class = DPRConfig config_class = DPRConfig
load_tf_weights = None load_tf_weights = None
base_model_prefix = "question_encoder" base_model_prefix = "question_encoder"
authorized_missing_keys = [r"position_ids"] _keys_to_ignore_on_load_missing = [r"position_ids"]
def init_weights(self): def init_weights(self):
self.question_encoder.init_weights() self.question_encoder.init_weights()
@@ -309,7 +309,7 @@ class DPRPretrainedReader(PreTrainedModel):
config_class = DPRConfig config_class = DPRConfig
load_tf_weights = None load_tf_weights = None
base_model_prefix = "span_predictor" base_model_prefix = "span_predictor"
authorized_missing_keys = [r"position_ids"] _keys_to_ignore_on_load_missing = [r"position_ids"]
def init_weights(self): def init_weights(self):
self.span_predictor.encoder.init_weights() self.span_predictor.encoder.init_weights()

View File

@@ -544,8 +544,8 @@ class ElectraPreTrainedModel(PreTrainedModel):
config_class = ElectraConfig config_class = ElectraConfig
load_tf_weights = load_tf_weights_in_electra load_tf_weights = load_tf_weights_in_electra
base_model_prefix = "electra" base_model_prefix = "electra"
authorized_missing_keys = [r"position_ids"] _keys_to_ignore_on_load_missing = [r"position_ids"]
authorized_unexpected_keys = [r"electra\.embeddings_project\.weight", r"electra\.embeddings_project\.bias"] _keys_to_ignore_on_load_unexpected = [r"electra\.embeddings_project\.weight", r"electra\.embeddings_project\.bias"]
# Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
def _init_weights(self, module): def _init_weights(self, module):

View File

@@ -1005,11 +1005,11 @@ class FSMTModel(PretrainedFSMTModel):
) )
class FSMTForConditionalGeneration(PretrainedFSMTModel): class FSMTForConditionalGeneration(PretrainedFSMTModel):
base_model_prefix = "model" base_model_prefix = "model"
authorized_missing_keys = [ _keys_to_ignore_on_load_missing = [
"model.encoder.embed_positions.weight", "model.encoder.embed_positions.weight",
"model.decoder.embed_positions.weight", "model.decoder.embed_positions.weight",
] ]
keys_to_never_save = [ _keys_to_ignore_on_save = [
"model.encoder.embed_positions.weight", "model.encoder.embed_positions.weight",
"model.decoder.embed_positions.weight", "model.decoder.embed_positions.weight",
] ]

View File

@@ -780,7 +780,7 @@ class GPT2Model(GPT2PreTrainedModel):
GPT2_START_DOCSTRING, GPT2_START_DOCSTRING,
) )
class GPT2LMHeadModel(GPT2PreTrainedModel): class GPT2LMHeadModel(GPT2PreTrainedModel):
authorized_missing_keys = [r"h\.\d+\.attn\.masked_bias", r"lm_head\.weight"] _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"lm_head\.weight"]
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)
@@ -1097,7 +1097,7 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
GPT2_START_DOCSTRING, GPT2_START_DOCSTRING,
) )
class GPT2ForSequenceClassification(GPT2PreTrainedModel): class GPT2ForSequenceClassification(GPT2PreTrainedModel):
authorized_missing_keys = [r"h\.\d+\.attn\.masked_bias", r"lm_head\.weight"] _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"lm_head\.weight"]
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)

View File

@@ -509,7 +509,7 @@ class LayoutLMPreTrainedModel(PreTrainedModel):
config_class = LayoutLMConfig config_class = LayoutLMConfig
base_model_prefix = "layoutlm" base_model_prefix = "layoutlm"
authorized_missing_keys = [r"position_ids"] _keys_to_ignore_on_load_missing = [r"position_ids"]
def _init_weights(self, module): def _init_weights(self, module):
""" Initialize the weights """ """ Initialize the weights """

View File

@@ -1303,7 +1303,7 @@ class LongformerPreTrainedModel(PreTrainedModel):
config_class = LongformerConfig config_class = LongformerConfig
base_model_prefix = "longformer" base_model_prefix = "longformer"
authorized_missing_keys = [r"position_ids"] _keys_to_ignore_on_load_missing = [r"position_ids"]
def _init_weights(self, module): def _init_weights(self, module):
""" Initialize the weights """ """ Initialize the weights """
@@ -1621,7 +1621,7 @@ class LongformerModel(LongformerPreTrainedModel):
@add_start_docstrings("""Longformer Model with a `language modeling` head on top. """, LONGFORMER_START_DOCSTRING) @add_start_docstrings("""Longformer Model with a `language modeling` head on top. """, LONGFORMER_START_DOCSTRING)
class LongformerForMaskedLM(LongformerPreTrainedModel): class LongformerForMaskedLM(LongformerPreTrainedModel):
authorized_unexpected_keys = [r"pooler"] _keys_to_ignore_on_load_unexpected = [r"pooler"]
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)
@@ -1718,7 +1718,7 @@ class LongformerForMaskedLM(LongformerPreTrainedModel):
) )
class LongformerForSequenceClassification(LongformerPreTrainedModel): class LongformerForSequenceClassification(LongformerPreTrainedModel):
authorized_unexpected_keys = [r"pooler"] _keys_to_ignore_on_load_unexpected = [r"pooler"]
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)
@@ -1827,7 +1827,7 @@ class LongformerClassificationHead(nn.Module):
) )
class LongformerForQuestionAnswering(LongformerPreTrainedModel): class LongformerForQuestionAnswering(LongformerPreTrainedModel):
authorized_unexpected_keys = [r"pooler"] _keys_to_ignore_on_load_unexpected = [r"pooler"]
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)
@@ -1961,7 +1961,7 @@ class LongformerForQuestionAnswering(LongformerPreTrainedModel):
) )
class LongformerForTokenClassification(LongformerPreTrainedModel): class LongformerForTokenClassification(LongformerPreTrainedModel):
authorized_unexpected_keys = [r"pooler"] _keys_to_ignore_on_load_unexpected = [r"pooler"]
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)

View File

@@ -1961,7 +1961,7 @@ class TFLongformerModel(TFLongformerPreTrainedModel):
) )
class TFLongformerForMaskedLM(TFLongformerPreTrainedModel, TFMaskedLanguageModelingLoss): class TFLongformerForMaskedLM(TFLongformerPreTrainedModel, TFMaskedLanguageModelingLoss):
authorized_missing_keys = [r"pooler"] _keys_to_ignore_on_load_missing = [r"pooler"]
def __init__(self, config, *inputs, **kwargs): def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs) super().__init__(config, *inputs, **kwargs)
@@ -2048,7 +2048,7 @@ class TFLongformerForMaskedLM(TFLongformerPreTrainedModel, TFMaskedLanguageModel
) )
class TFLongformerForQuestionAnswering(TFLongformerPreTrainedModel, TFQuestionAnsweringLoss): class TFLongformerForQuestionAnswering(TFLongformerPreTrainedModel, TFQuestionAnsweringLoss):
authorized_missing_keys = [r"pooler"] _keys_to_ignore_on_load_missing = [r"pooler"]
def __init__(self, config, *inputs, **kwargs): def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs) super().__init__(config, *inputs, **kwargs)
@@ -2199,7 +2199,7 @@ class TFLongformerClassificationHead(tf.keras.layers.Layer):
) )
class TFLongformerForSequenceClassification(TFLongformerPreTrainedModel, TFSequenceClassificationLoss): class TFLongformerForSequenceClassification(TFLongformerPreTrainedModel, TFSequenceClassificationLoss):
authorized_missing_keys = [r"pooler"] _keys_to_ignore_on_load_missing = [r"pooler"]
def __init__(self, config, *inputs, **kwargs): def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs) super().__init__(config, *inputs, **kwargs)
@@ -2443,7 +2443,7 @@ class TFLongformerForMultipleChoice(TFLongformerPreTrainedModel, TFMultipleChoic
) )
class TFLongformerForTokenClassification(TFLongformerPreTrainedModel, TFTokenClassificationLoss): class TFLongformerForTokenClassification(TFLongformerPreTrainedModel, TFTokenClassificationLoss):
authorized_missing_keys = [r"pooler"] _keys_to_ignore_on_load_missing = [r"pooler"]
def __init__(self, config, *inputs, **kwargs): def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs) super().__init__(config, *inputs, **kwargs)

View File

@@ -47,11 +47,11 @@ class MarianMTModel(BartForConditionalGeneration):
""" """
config_class = MarianConfig config_class = MarianConfig
authorized_missing_keys = [ _keys_to_ignore_on_load_missing = [
"model.encoder.embed_positions.weight", "model.encoder.embed_positions.weight",
"model.decoder.embed_positions.weight", "model.decoder.embed_positions.weight",
] ]
keys_to_never_save = [ _keys_to_ignore_on_save = [
"model.encoder.embed_positions.weight", "model.encoder.embed_positions.weight",
"model.decoder.embed_positions.weight", "model.decoder.embed_positions.weight",
] ]

View File

@@ -37,7 +37,7 @@ logger = logging.get_logger(__name__)
@add_start_docstrings("Marian model for machine translation", START_DOCSTRING) @add_start_docstrings("Marian model for machine translation", START_DOCSTRING)
class TFMarianMTModel(TFBartForConditionalGeneration): class TFMarianMTModel(TFBartForConditionalGeneration):
authorized_missing_keys = [ _keys_to_ignore_on_load_missing = [
r"model.encoder.embed_positions.weight", r"model.encoder.embed_positions.weight",
r"model.decoder.embed_positions.weight", r"model.decoder.embed_positions.weight",
] ]

View File

@@ -29,11 +29,11 @@ class MBartForConditionalGeneration(BartForConditionalGeneration):
""" """
model_type = "mbart" model_type = "mbart"
config_class = MBartConfig config_class = MBartConfig
authorized_missing_keys = [ _keys_to_ignore_on_load_missing = [
"model.encoder.embed_positions.weight", "model.encoder.embed_positions.weight",
"model.decoder.embed_positions.weight", "model.decoder.embed_positions.weight",
] ]
keys_to_never_save = [ _keys_to_ignore_on_save = [
"model.encoder.embed_positions.weight", "model.encoder.embed_positions.weight",
"model.decoder.embed_positions.weight", "model.decoder.embed_positions.weight",
] ]

View File

@@ -677,7 +677,7 @@ class MobileBertPreTrainedModel(PreTrainedModel):
pretrained_model_archive_map = MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST pretrained_model_archive_map = MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST
load_tf_weights = load_tf_weights_in_mobilebert load_tf_weights = load_tf_weights_in_mobilebert
base_model_prefix = "mobilebert" base_model_prefix = "mobilebert"
authorized_missing_keys = [r"position_ids"] _keys_to_ignore_on_load_missing = [r"position_ids"]
def _init_weights(self, module): def _init_weights(self, module):
""" Initialize the weights """ """ Initialize the weights """
@@ -1054,7 +1054,7 @@ class MobileBertForPreTraining(MobileBertPreTrainedModel):
@add_start_docstrings("""MobileBert Model with a `language modeling` head on top. """, MOBILEBERT_START_DOCSTRING) @add_start_docstrings("""MobileBert Model with a `language modeling` head on top. """, MOBILEBERT_START_DOCSTRING)
class MobileBertForMaskedLM(MobileBertPreTrainedModel): class MobileBertForMaskedLM(MobileBertPreTrainedModel):
authorized_unexpected_keys = [r"pooler"] _keys_to_ignore_on_load_unexpected = [r"pooler"]
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)
@@ -1350,7 +1350,7 @@ class MobileBertForSequenceClassification(MobileBertPreTrainedModel):
) )
class MobileBertForQuestionAnswering(MobileBertPreTrainedModel): class MobileBertForQuestionAnswering(MobileBertPreTrainedModel):
authorized_unexpected_keys = [r"pooler"] _keys_to_ignore_on_load_unexpected = [r"pooler"]
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)
@@ -1545,7 +1545,7 @@ class MobileBertForMultipleChoice(MobileBertPreTrainedModel):
) )
class MobileBertForTokenClassification(MobileBertPreTrainedModel): class MobileBertForTokenClassification(MobileBertPreTrainedModel):
authorized_unexpected_keys = [r"pooler"] _keys_to_ignore_on_load_unexpected = [r"pooler"]
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)

View File

@@ -1030,7 +1030,7 @@ class TFMobileBertForPreTraining(TFMobileBertPreTrainedModel):
@add_start_docstrings("""MobileBert Model with a `language modeling` head on top. """, MOBILEBERT_START_DOCSTRING) @add_start_docstrings("""MobileBert Model with a `language modeling` head on top. """, MOBILEBERT_START_DOCSTRING)
class TFMobileBertForMaskedLM(TFMobileBertPreTrainedModel, TFMaskedLanguageModelingLoss): class TFMobileBertForMaskedLM(TFMobileBertPreTrainedModel, TFMaskedLanguageModelingLoss):
authorized_missing_keys = [r"pooler"] _keys_to_ignore_on_load_missing = [r"pooler"]
def __init__(self, config, *inputs, **kwargs): def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs) super().__init__(config, *inputs, **kwargs)
@@ -1297,7 +1297,7 @@ class TFMobileBertForSequenceClassification(TFMobileBertPreTrainedModel, TFSeque
) )
class TFMobileBertForQuestionAnswering(TFMobileBertPreTrainedModel, TFQuestionAnsweringLoss): class TFMobileBertForQuestionAnswering(TFMobileBertPreTrainedModel, TFQuestionAnsweringLoss):
authorized_missing_keys = [r"pooler"] _keys_to_ignore_on_load_missing = [r"pooler"]
def __init__(self, config, *inputs, **kwargs): def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs) super().__init__(config, *inputs, **kwargs)
@@ -1529,7 +1529,7 @@ class TFMobileBertForMultipleChoice(TFMobileBertPreTrainedModel, TFMultipleChoic
) )
class TFMobileBertForTokenClassification(TFMobileBertPreTrainedModel, TFTokenClassificationLoss): class TFMobileBertForTokenClassification(TFMobileBertPreTrainedModel, TFTokenClassificationLoss):
authorized_missing_keys = [r"pooler"] _keys_to_ignore_on_load_missing = [r"pooler"]
def __init__(self, config, *inputs, **kwargs): def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs) super().__init__(config, *inputs, **kwargs)

View File

@@ -42,12 +42,12 @@ class MT5Model(T5Model):
""" """
model_type = "mt5" model_type = "mt5"
config_class = MT5Config config_class = MT5Config
authorized_missing_keys = [ _keys_to_ignore_on_load_missing = [
r"encoder\.embed_tokens\.weight", r"encoder\.embed_tokens\.weight",
r"decoder\.embed_tokens\.weight", r"decoder\.embed_tokens\.weight",
r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight", r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight",
] ]
keys_to_never_save = [ _keys_to_ignore_on_save = [
r"encoder\.embed_tokens\.weight", r"encoder\.embed_tokens\.weight",
r"decoder\.embed_tokens\.weight", r"decoder\.embed_tokens\.weight",
] ]
@@ -71,13 +71,13 @@ class MT5ForConditionalGeneration(T5ForConditionalGeneration):
model_type = "mt5" model_type = "mt5"
config_class = MT5Config config_class = MT5Config
authorized_missing_keys = [ _keys_to_ignore_on_load_missing = [
r"encoder\.embed_tokens\.weight", r"encoder\.embed_tokens\.weight",
r"decoder\.embed_tokens\.weight", r"decoder\.embed_tokens\.weight",
r"lm_head\.weight", r"lm_head\.weight",
r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight", r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight",
] ]
keys_to_never_save = [ _keys_to_ignore_on_save = [
r"encoder\.embed_tokens\.weight", r"encoder\.embed_tokens\.weight",
r"decoder\.embed_tokens\.weight", r"decoder\.embed_tokens\.weight",
] ]

View File

@@ -279,7 +279,7 @@ class OpenAIGPTPreTrainedModel(PreTrainedModel):
config_class = OpenAIGPTConfig config_class = OpenAIGPTConfig
load_tf_weights = load_tf_weights_in_openai_gpt load_tf_weights = load_tf_weights_in_openai_gpt
base_model_prefix = "transformer" base_model_prefix = "transformer"
authorized_missing_keys = [r"position_ids"] _keys_to_ignore_on_load_missing = [r"position_ids"]
def _init_weights(self, module): def _init_weights(self, module):
"""Initialize the weights.""" """Initialize the weights."""

View File

@@ -46,14 +46,14 @@ class PegasusForConditionalGeneration(BartForConditionalGeneration):
""" """
# All the code is in src/transformers/models/bart/modeling_bart.py # All the code is in src/transformers/models/bart/modeling_bart.py
config_class = PegasusConfig config_class = PegasusConfig
authorized_missing_keys = [ _keys_to_ignore_on_load_missing = [
r"final_logits_bias", r"final_logits_bias",
r"encoder\.version", r"encoder\.version",
r"decoder\.version", r"decoder\.version",
"model.encoder.embed_positions", "model.encoder.embed_positions",
"model.decoder.embed_positions", "model.decoder.embed_positions",
] ]
keys_to_never_save = [ _keys_to_ignore_on_save = [
"model.encoder.embed_positions.weight", "model.encoder.embed_positions.weight",
"model.decoder.embed_positions.weight", "model.decoder.embed_positions.weight",
] ]

View File

@@ -32,7 +32,7 @@ logger = logging.get_logger(__name__)
@add_start_docstrings("Pegasus model for summarization", START_DOCSTRING) @add_start_docstrings("Pegasus model for summarization", START_DOCSTRING)
class TFPegasusForConditionalGeneration(TFBartForConditionalGeneration): class TFPegasusForConditionalGeneration(TFBartForConditionalGeneration):
authorized_missing_keys = [ _keys_to_ignore_on_load_missing = [
r"final_logits_bias", r"final_logits_bias",
r"model.encoder.embed_positions.weight", r"model.encoder.embed_positions.weight",
r"model.decoder.embed_positions.weight", r"model.decoder.embed_positions.weight",

View File

@@ -216,7 +216,7 @@ class RagPreTrainedModel(PreTrainedModel):
""" """
config_class = RagConfig config_class = RagConfig
base_model_prefix = "rag" base_model_prefix = "rag"
authorized_missing_keys = [r"position_ids"] _keys_to_ignore_on_load_missing = [r"position_ids"]
@classmethod @classmethod
def from_pretrained_question_encoder_generator( def from_pretrained_question_encoder_generator(

View File

@@ -576,7 +576,7 @@ class RobertaModel(RobertaPreTrainedModel):
""" """
authorized_missing_keys = [r"position_ids"] _keys_to_ignore_on_load_missing = [r"position_ids"]
# Copied from transformers.models.bert.modeling_bert.BertModel.__init__ with Bert->Roberta # Copied from transformers.models.bert.modeling_bert.BertModel.__init__ with Bert->Roberta
def __init__(self, config, add_pooling_layer=True): def __init__(self, config, add_pooling_layer=True):
@@ -711,8 +711,8 @@ class RobertaModel(RobertaPreTrainedModel):
"""RoBERTa Model with a `language modeling` head on top for CLM fine-tuning. """, ROBERTA_START_DOCSTRING """RoBERTa Model with a `language modeling` head on top for CLM fine-tuning. """, ROBERTA_START_DOCSTRING
) )
class RobertaForCausalLM(RobertaPreTrainedModel): class RobertaForCausalLM(RobertaPreTrainedModel):
authorized_missing_keys = [r"position_ids", r"predictions.decoder.bias"] _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
authorized_unexpected_keys = [r"pooler"] _keys_to_ignore_on_load_unexpected = [r"pooler"]
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)
@@ -829,8 +829,8 @@ class RobertaForCausalLM(RobertaPreTrainedModel):
@add_start_docstrings("""RoBERTa Model with a `language modeling` head on top. """, ROBERTA_START_DOCSTRING) @add_start_docstrings("""RoBERTa Model with a `language modeling` head on top. """, ROBERTA_START_DOCSTRING)
class RobertaForMaskedLM(RobertaPreTrainedModel): class RobertaForMaskedLM(RobertaPreTrainedModel):
authorized_missing_keys = [r"position_ids", r"predictions.decoder.bias"] _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
authorized_unexpected_keys = [r"pooler"] _keys_to_ignore_on_load_unexpected = [r"pooler"]
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)
@@ -948,7 +948,7 @@ class RobertaLMHead(nn.Module):
ROBERTA_START_DOCSTRING, ROBERTA_START_DOCSTRING,
) )
class RobertaForSequenceClassification(RobertaPreTrainedModel): class RobertaForSequenceClassification(RobertaPreTrainedModel):
authorized_missing_keys = [r"position_ids"] _keys_to_ignore_on_load_missing = [r"position_ids"]
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)
@@ -1031,7 +1031,7 @@ class RobertaForSequenceClassification(RobertaPreTrainedModel):
ROBERTA_START_DOCSTRING, ROBERTA_START_DOCSTRING,
) )
class RobertaForMultipleChoice(RobertaPreTrainedModel): class RobertaForMultipleChoice(RobertaPreTrainedModel):
authorized_missing_keys = [r"position_ids"] _keys_to_ignore_on_load_missing = [r"position_ids"]
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)
@@ -1123,8 +1123,8 @@ class RobertaForMultipleChoice(RobertaPreTrainedModel):
ROBERTA_START_DOCSTRING, ROBERTA_START_DOCSTRING,
) )
class RobertaForTokenClassification(RobertaPreTrainedModel): class RobertaForTokenClassification(RobertaPreTrainedModel):
authorized_unexpected_keys = [r"pooler"] _keys_to_ignore_on_load_unexpected = [r"pooler"]
authorized_missing_keys = [r"position_ids"] _keys_to_ignore_on_load_missing = [r"position_ids"]
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)
@@ -1233,8 +1233,8 @@ class RobertaClassificationHead(nn.Module):
ROBERTA_START_DOCSTRING, ROBERTA_START_DOCSTRING,
) )
class RobertaForQuestionAnswering(RobertaPreTrainedModel): class RobertaForQuestionAnswering(RobertaPreTrainedModel):
authorized_unexpected_keys = [r"pooler"] _keys_to_ignore_on_load_unexpected = [r"pooler"]
authorized_missing_keys = [r"position_ids"] _keys_to_ignore_on_load_missing = [r"position_ids"]
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)

View File

@@ -765,7 +765,7 @@ class TFRobertaLMHead(tf.keras.layers.Layer):
@add_start_docstrings("""RoBERTa Model with a `language modeling` head on top. """, ROBERTA_START_DOCSTRING) @add_start_docstrings("""RoBERTa Model with a `language modeling` head on top. """, ROBERTA_START_DOCSTRING)
class TFRobertaForMaskedLM(TFRobertaPreTrainedModel, TFMaskedLanguageModelingLoss): class TFRobertaForMaskedLM(TFRobertaPreTrainedModel, TFMaskedLanguageModelingLoss):
authorized_missing_keys = [r"pooler"] _keys_to_ignore_on_load_missing = [r"pooler"]
def __init__(self, config, *inputs, **kwargs): def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs) super().__init__(config, *inputs, **kwargs)
@@ -877,7 +877,7 @@ class TFRobertaClassificationHead(tf.keras.layers.Layer):
) )
class TFRobertaForSequenceClassification(TFRobertaPreTrainedModel, TFSequenceClassificationLoss): class TFRobertaForSequenceClassification(TFRobertaPreTrainedModel, TFSequenceClassificationLoss):
authorized_missing_keys = [r"pooler"] _keys_to_ignore_on_load_missing = [r"pooler"]
def __init__(self, config, *inputs, **kwargs): def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs) super().__init__(config, *inputs, **kwargs)
@@ -1084,7 +1084,7 @@ class TFRobertaForMultipleChoice(TFRobertaPreTrainedModel, TFMultipleChoiceLoss)
) )
class TFRobertaForTokenClassification(TFRobertaPreTrainedModel, TFTokenClassificationLoss): class TFRobertaForTokenClassification(TFRobertaPreTrainedModel, TFTokenClassificationLoss):
authorized_missing_keys = [r"pooler"] _keys_to_ignore_on_load_missing = [r"pooler"]
def __init__(self, config, *inputs, **kwargs): def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs) super().__init__(config, *inputs, **kwargs)
@@ -1171,7 +1171,7 @@ class TFRobertaForTokenClassification(TFRobertaPreTrainedModel, TFTokenClassific
) )
class TFRobertaForQuestionAnswering(TFRobertaPreTrainedModel, TFQuestionAnsweringLoss): class TFRobertaForQuestionAnswering(TFRobertaPreTrainedModel, TFQuestionAnsweringLoss):
authorized_missing_keys = [r"pooler"] _keys_to_ignore_on_load_missing = [r"pooler"]
def __init__(self, config, *inputs, **kwargs): def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs) super().__init__(config, *inputs, **kwargs)

View File

@@ -428,7 +428,7 @@ class SqueezeBertPreTrainedModel(PreTrainedModel):
config_class = SqueezeBertConfig config_class = SqueezeBertConfig
base_model_prefix = "transformer" base_model_prefix = "transformer"
authorized_missing_keys = [r"position_ids"] _keys_to_ignore_on_load_missing = [r"position_ids"]
def _init_weights(self, module): def _init_weights(self, module):
""" Initialize the weights """ """ Initialize the weights """
@@ -642,7 +642,7 @@ class SqueezeBertModel(SqueezeBertPreTrainedModel):
@add_start_docstrings("""SqueezeBERT Model with a `language modeling` head on top. """, SQUEEZEBERT_START_DOCSTRING) @add_start_docstrings("""SqueezeBERT Model with a `language modeling` head on top. """, SQUEEZEBERT_START_DOCSTRING)
class SqueezeBertForMaskedLM(SqueezeBertPreTrainedModel): class SqueezeBertForMaskedLM(SqueezeBertPreTrainedModel):
authorized_missing_keys = [r"predictions.decoder.bias"] _keys_to_ignore_on_load_missing = [r"predictions.decoder.bias"]
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)

View File

@@ -1086,7 +1086,7 @@ T5_INPUTS_DOCSTRING = r"""
T5_START_DOCSTRING, T5_START_DOCSTRING,
) )
class T5Model(T5PreTrainedModel): class T5Model(T5PreTrainedModel):
authorized_missing_keys = [ _keys_to_ignore_on_load_missing = [
r"encoder\.embed_tokens\.weight", r"encoder\.embed_tokens\.weight",
r"decoder\.embed_tokens\.weight", r"decoder\.embed_tokens\.weight",
r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight", r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight",
@@ -1258,7 +1258,7 @@ class T5Model(T5PreTrainedModel):
@add_start_docstrings("""T5 Model with a `language modeling` head on top. """, T5_START_DOCSTRING) @add_start_docstrings("""T5 Model with a `language modeling` head on top. """, T5_START_DOCSTRING)
class T5ForConditionalGeneration(T5PreTrainedModel): class T5ForConditionalGeneration(T5PreTrainedModel):
authorized_missing_keys = [ _keys_to_ignore_on_load_missing = [
r"encoder\.embed_tokens\.weight", r"encoder\.embed_tokens\.weight",
r"decoder\.embed_tokens\.weight", r"decoder\.embed_tokens\.weight",
r"lm_head\.weight", r"lm_head\.weight",

View File

@@ -399,7 +399,7 @@ XLM_INPUTS_DOCSTRING = r"""
XLM_START_DOCSTRING, XLM_START_DOCSTRING,
) )
class XLMModel(XLMPreTrainedModel): class XLMModel(XLMPreTrainedModel):
authorized_missing_keys = [r"position_ids"] _keys_to_ignore_on_load_missing = [r"position_ids"]
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)

View File

@@ -540,7 +540,7 @@ class {{cookiecutter.camelcase_modelname}}PreTrainedModel(PreTrainedModel):
config_class = {{cookiecutter.camelcase_modelname}}Config config_class = {{cookiecutter.camelcase_modelname}}Config
load_tf_weights = load_tf_weights_in_{{cookiecutter.lowercase_modelname}} load_tf_weights = load_tf_weights_in_{{cookiecutter.lowercase_modelname}}
base_model_prefix = "{{cookiecutter.lowercase_modelname}}" base_model_prefix = "{{cookiecutter.lowercase_modelname}}"
authorized_missing_keys = [r"position_ids"] _keys_to_ignore_on_load_missing = [r"position_ids"]
def _init_weights(self, module): def _init_weights(self, module):
""" Initialize the weights """ """ Initialize the weights """

View File

@@ -135,17 +135,17 @@ class ModelTesterMixin:
max_diff = np.amax(np.abs(out_1 - out_2)) max_diff = np.amax(np.abs(out_1 - out_2))
self.assertLessEqual(max_diff, 1e-5) self.assertLessEqual(max_diff, 1e-5)
def test_save_load_keys_to_never_save(self): def test_save_load__keys_to_ignore_on_save(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes: for model_class in self.all_model_classes:
model = model_class(config) model = model_class(config)
keys_to_never_save = getattr(model, "keys_to_never_save", None) _keys_to_ignore_on_save = getattr(model, "_keys_to_ignore_on_save", None)
if keys_to_never_save is None: if _keys_to_ignore_on_save is None:
continue continue
# check the keys are in the original state_dict # check the keys are in the original state_dict
for k in keys_to_never_save: for k in _keys_to_ignore_on_save:
self.assertIn(k, model.state_dict()) self.assertIn(k, model.state_dict())
# check that certain keys didn't get saved with the model # check that certain keys didn't get saved with the model
@@ -153,7 +153,7 @@ class ModelTesterMixin:
model.save_pretrained(tmpdirname) model.save_pretrained(tmpdirname)
output_model_file = os.path.join(tmpdirname, WEIGHTS_NAME) output_model_file = os.path.join(tmpdirname, WEIGHTS_NAME)
state_dict_saved = torch.load(output_model_file) state_dict_saved = torch.load(output_model_file)
for k in keys_to_never_save: for k in _keys_to_ignore_on_save:
self.assertNotIn(k, state_dict_saved) self.assertNotIn(k, state_dict_saved)
def test_initialization(self): def test_initialization(self):

View File

@@ -60,7 +60,7 @@ class ModelTester:
class SelectiveCommonTest(unittest.TestCase): class SelectiveCommonTest(unittest.TestCase):
all_model_classes = (MarianMTModel,) if is_torch_available() else () all_model_classes = (MarianMTModel,) if is_torch_available() else ()
test_save_load_keys_to_never_save = ModelTesterMixin.test_save_load_keys_to_never_save test_save_load__keys_to_ignore_on_save = ModelTesterMixin.test_save_load__keys_to_ignore_on_save
def setUp(self): def setUp(self):
self.model_tester = ModelTester(self) self.model_tester = ModelTester(self)

View File

@@ -47,7 +47,7 @@ class ModelTester:
class SelectiveCommonTest(unittest.TestCase): class SelectiveCommonTest(unittest.TestCase):
all_model_classes = (MBartForConditionalGeneration,) if is_torch_available() else () all_model_classes = (MBartForConditionalGeneration,) if is_torch_available() else ()
test_save_load_keys_to_never_save = ModelTesterMixin.test_save_load_keys_to_never_save test_save_load__keys_to_ignore_on_save = ModelTesterMixin.test_save_load__keys_to_ignore_on_save
def setUp(self): def setUp(self):
self.model_tester = ModelTester(self) self.model_tester = ModelTester(self)

View File

@@ -43,7 +43,7 @@ class ModelTester:
class SelectiveCommonTest(unittest.TestCase): class SelectiveCommonTest(unittest.TestCase):
all_model_classes = (PegasusForConditionalGeneration,) if is_torch_available() else () all_model_classes = (PegasusForConditionalGeneration,) if is_torch_available() else ()
test_save_load_keys_to_never_save = ModelTesterMixin.test_save_load_keys_to_never_save test_save_load__keys_to_ignore_on_save = ModelTesterMixin.test_save_load__keys_to_ignore_on_save
def setUp(self): def setUp(self):
self.model_tester = ModelTester(self) self.model_tester = ModelTester(self)