consistent ignore keys + make private (#8737)
* consistent ignore keys + make private * style * - authorized_missing_keys => _keys_to_ignore_on_load_missing - authorized_unexpected_keys => _keys_to_ignore_on_load_unexpected * move public doc of private attributes to private comment
This commit is contained in:
@@ -164,9 +164,9 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
|
|||||||
if allow_missing_keys:
|
if allow_missing_keys:
|
||||||
missing_keys.append(name)
|
missing_keys.append(name)
|
||||||
continue
|
continue
|
||||||
elif tf_model.authorized_missing_keys is not None:
|
elif tf_model._keys_to_ignore_on_load_missing is not None:
|
||||||
# authorized missing keys don't have to be loaded
|
# authorized missing keys don't have to be loaded
|
||||||
if any(re.search(pat, name) is not None for pat in tf_model.authorized_missing_keys):
|
if any(re.search(pat, name) is not None for pat in tf_model._keys_to_ignore_on_load_missing):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
raise AttributeError("{} not found in PyTorch model".format(name))
|
raise AttributeError("{} not found in PyTorch model".format(name))
|
||||||
@@ -209,11 +209,11 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
|
|||||||
|
|
||||||
unexpected_keys = list(all_pytorch_weights)
|
unexpected_keys = list(all_pytorch_weights)
|
||||||
|
|
||||||
if tf_model.authorized_missing_keys is not None:
|
if tf_model._keys_to_ignore_on_load_missing is not None:
|
||||||
for pat in tf_model.authorized_missing_keys:
|
for pat in tf_model._keys_to_ignore_on_load_missing:
|
||||||
missing_keys = [k for k in missing_keys if re.search(pat, k) is None]
|
missing_keys = [k for k in missing_keys if re.search(pat, k) is None]
|
||||||
if tf_model.authorized_unexpected_keys is not None:
|
if tf_model._keys_to_ignore_on_load_unexpected is not None:
|
||||||
for pat in tf_model.authorized_unexpected_keys:
|
for pat in tf_model._keys_to_ignore_on_load_unexpected:
|
||||||
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
|
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
|
||||||
|
|
||||||
if len(unexpected_keys) > 0:
|
if len(unexpected_keys) > 0:
|
||||||
|
|||||||
@@ -343,15 +343,15 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
|
|||||||
:class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture.
|
:class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture.
|
||||||
- **base_model_prefix** (:obj:`str`) -- A string indicating the attribute associated to the base model in
|
- **base_model_prefix** (:obj:`str`) -- A string indicating the attribute associated to the base model in
|
||||||
derived classes of the same architecture adding modules on top of the base model.
|
derived classes of the same architecture adding modules on top of the base model.
|
||||||
- **authorized_missing_keys** (:obj:`List[str]`, `optional`) -- A list of re pattern of tensor names to ignore
|
|
||||||
from the model when loading the model weights (and avoid unnecessary warnings).
|
|
||||||
- **authorized_unexpected_keys** (:obj:`List[str]`, `optional`) -- A list of re pattern of tensor names to
|
|
||||||
ignore from the weights when loading the model weights (and avoid unnecessary warnings).
|
|
||||||
"""
|
"""
|
||||||
config_class = None
|
config_class = None
|
||||||
base_model_prefix = ""
|
base_model_prefix = ""
|
||||||
authorized_missing_keys = None
|
# a list of re pattern of tensor names to ignore from the model when loading the model weights
|
||||||
authorized_unexpected_keys = None
|
# (and avoid unnecessary warnings).
|
||||||
|
_keys_to_ignore_on_load_missing = None
|
||||||
|
# a list of re pattern of tensor names to ignore from the weights when loading the model weights
|
||||||
|
# (and avoid unnecessary warnings).
|
||||||
|
_keys_to_ignore_on_load_unexpected = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def dummy_inputs(self) -> Dict[str, tf.Tensor]:
|
def dummy_inputs(self) -> Dict[str, tf.Tensor]:
|
||||||
@@ -742,12 +742,12 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
|
|||||||
|
|
||||||
model(model.dummy_inputs, training=False) # Make sure restore ops are run
|
model(model.dummy_inputs, training=False) # Make sure restore ops are run
|
||||||
|
|
||||||
if cls.authorized_missing_keys is not None:
|
if cls._keys_to_ignore_on_load_missing is not None:
|
||||||
for pat in cls.authorized_missing_keys:
|
for pat in cls._keys_to_ignore_on_load_missing:
|
||||||
missing_keys = [k for k in missing_keys if re.search(pat, k) is None]
|
missing_keys = [k for k in missing_keys if re.search(pat, k) is None]
|
||||||
|
|
||||||
if cls.authorized_unexpected_keys is not None:
|
if cls._keys_to_ignore_on_load_unexpected is not None:
|
||||||
for pat in cls.authorized_unexpected_keys:
|
for pat in cls._keys_to_ignore_on_load_unexpected:
|
||||||
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
|
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
|
||||||
|
|
||||||
if len(unexpected_keys) > 0:
|
if len(unexpected_keys) > 0:
|
||||||
|
|||||||
@@ -404,17 +404,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
|||||||
|
|
||||||
- **base_model_prefix** (:obj:`str`) -- A string indicating the attribute associated to the base model in
|
- **base_model_prefix** (:obj:`str`) -- A string indicating the attribute associated to the base model in
|
||||||
derived classes of the same architecture adding modules on top of the base model.
|
derived classes of the same architecture adding modules on top of the base model.
|
||||||
- **authorized_missing_keys** (:obj:`Optional[List[str]]`) -- A list of re pattern of tensor names to ignore
|
|
||||||
when loading the model (and avoid unnecessary warnings).
|
|
||||||
- **keys_to_never_save** (:obj:`Optional[List[str]]`) -- A list of of tensor names to ignore when saving the
|
|
||||||
model (useful for keys that aren't trained, but which are deterministic)
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
config_class = None
|
config_class = None
|
||||||
base_model_prefix = ""
|
base_model_prefix = ""
|
||||||
authorized_missing_keys = None
|
# a list of re pattern of tensor names to ignore from the model when loading the model weights
|
||||||
authorized_unexpected_keys = None
|
# (and avoid unnecessary warnings).
|
||||||
keys_to_never_save = None
|
_keys_to_ignore_on_load_missing = None
|
||||||
|
# a list of re pattern of tensor names to ignore from the weights when loading the model weights
|
||||||
|
# (and avoid unnecessary warnings).
|
||||||
|
_keys_to_ignore_on_load_unexpected = None
|
||||||
|
# a list of of tensor names to ignore when saving the model (useful for keys that aren't
|
||||||
|
# trained, but which are deterministic)
|
||||||
|
_keys_to_ignore_on_save = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def dummy_inputs(self) -> Dict[str, torch.Tensor]:
|
def dummy_inputs(self) -> Dict[str, torch.Tensor]:
|
||||||
@@ -719,8 +720,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
|||||||
state_dict = model_to_save.state_dict()
|
state_dict = model_to_save.state_dict()
|
||||||
|
|
||||||
# Handle the case where some state_dict keys shouldn't be saved
|
# Handle the case where some state_dict keys shouldn't be saved
|
||||||
if self.keys_to_never_save is not None:
|
if self._keys_to_ignore_on_save is not None:
|
||||||
state_dict = {k: v for k, v in state_dict.items() if k not in self.keys_to_never_save}
|
state_dict = {k: v for k, v in state_dict.items() if k not in self._keys_to_ignore_on_save}
|
||||||
|
|
||||||
# If we save using the predefined names, we can load using `from_pretrained`
|
# If we save using the predefined names, we can load using `from_pretrained`
|
||||||
output_model_file = os.path.join(save_directory, WEIGHTS_NAME)
|
output_model_file = os.path.join(save_directory, WEIGHTS_NAME)
|
||||||
@@ -1034,12 +1035,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
|||||||
|
|
||||||
# Some models may have keys that are not in the state by design, removing them before needlessly warning
|
# Some models may have keys that are not in the state by design, removing them before needlessly warning
|
||||||
# the user.
|
# the user.
|
||||||
if cls.authorized_missing_keys is not None:
|
if cls._keys_to_ignore_on_load_missing is not None:
|
||||||
for pat in cls.authorized_missing_keys:
|
for pat in cls._keys_to_ignore_on_load_missing:
|
||||||
missing_keys = [k for k in missing_keys if re.search(pat, k) is None]
|
missing_keys = [k for k in missing_keys if re.search(pat, k) is None]
|
||||||
|
|
||||||
if cls.authorized_unexpected_keys is not None:
|
if cls._keys_to_ignore_on_load_unexpected is not None:
|
||||||
for pat in cls.authorized_unexpected_keys:
|
for pat in cls._keys_to_ignore_on_load_unexpected:
|
||||||
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
|
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
|
||||||
|
|
||||||
if len(unexpected_keys) > 0:
|
if len(unexpected_keys) > 0:
|
||||||
|
|||||||
@@ -459,7 +459,7 @@ class AlbertPreTrainedModel(PreTrainedModel):
|
|||||||
|
|
||||||
config_class = AlbertConfig
|
config_class = AlbertConfig
|
||||||
base_model_prefix = "albert"
|
base_model_prefix = "albert"
|
||||||
authorized_missing_keys = [r"position_ids"]
|
_keys_to_ignore_on_load_missing = [r"position_ids"]
|
||||||
|
|
||||||
def _init_weights(self, module):
|
def _init_weights(self, module):
|
||||||
"""Initialize the weights."""
|
"""Initialize the weights."""
|
||||||
@@ -851,7 +851,7 @@ class AlbertSOPHead(nn.Module):
|
|||||||
)
|
)
|
||||||
class AlbertForMaskedLM(AlbertPreTrainedModel):
|
class AlbertForMaskedLM(AlbertPreTrainedModel):
|
||||||
|
|
||||||
authorized_unexpected_keys = [r"pooler"]
|
_keys_to_ignore_on_load_unexpected = [r"pooler"]
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
@@ -1021,7 +1021,7 @@ class AlbertForSequenceClassification(AlbertPreTrainedModel):
|
|||||||
)
|
)
|
||||||
class AlbertForTokenClassification(AlbertPreTrainedModel):
|
class AlbertForTokenClassification(AlbertPreTrainedModel):
|
||||||
|
|
||||||
authorized_unexpected_keys = [r"pooler"]
|
_keys_to_ignore_on_load_unexpected = [r"pooler"]
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
@@ -1110,7 +1110,7 @@ class AlbertForTokenClassification(AlbertPreTrainedModel):
|
|||||||
)
|
)
|
||||||
class AlbertForQuestionAnswering(AlbertPreTrainedModel):
|
class AlbertForQuestionAnswering(AlbertPreTrainedModel):
|
||||||
|
|
||||||
authorized_unexpected_keys = [r"pooler"]
|
_keys_to_ignore_on_load_unexpected = [r"pooler"]
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|||||||
@@ -843,7 +843,7 @@ class TFAlbertSOPHead(tf.keras.layers.Layer):
|
|||||||
@add_start_docstrings("""Albert Model with a `language modeling` head on top. """, ALBERT_START_DOCSTRING)
|
@add_start_docstrings("""Albert Model with a `language modeling` head on top. """, ALBERT_START_DOCSTRING)
|
||||||
class TFAlbertForMaskedLM(TFAlbertPreTrainedModel, TFMaskedLanguageModelingLoss):
|
class TFAlbertForMaskedLM(TFAlbertPreTrainedModel, TFMaskedLanguageModelingLoss):
|
||||||
|
|
||||||
authorized_missing_keys = [r"pooler"]
|
_keys_to_ignore_on_load_missing = [r"pooler"]
|
||||||
|
|
||||||
def __init__(self, config, *inputs, **kwargs):
|
def __init__(self, config, *inputs, **kwargs):
|
||||||
super().__init__(config, *inputs, **kwargs)
|
super().__init__(config, *inputs, **kwargs)
|
||||||
@@ -1013,7 +1013,7 @@ class TFAlbertForSequenceClassification(TFAlbertPreTrainedModel, TFSequenceClass
|
|||||||
)
|
)
|
||||||
class TFAlbertForTokenClassification(TFAlbertPreTrainedModel, TFTokenClassificationLoss):
|
class TFAlbertForTokenClassification(TFAlbertPreTrainedModel, TFTokenClassificationLoss):
|
||||||
|
|
||||||
authorized_missing_keys = [r"pooler"]
|
_keys_to_ignore_on_load_missing = [r"pooler"]
|
||||||
|
|
||||||
def __init__(self, config, *inputs, **kwargs):
|
def __init__(self, config, *inputs, **kwargs):
|
||||||
super().__init__(config, *inputs, **kwargs)
|
super().__init__(config, *inputs, **kwargs)
|
||||||
@@ -1100,7 +1100,7 @@ class TFAlbertForTokenClassification(TFAlbertPreTrainedModel, TFTokenClassificat
|
|||||||
)
|
)
|
||||||
class TFAlbertForQuestionAnswering(TFAlbertPreTrainedModel, TFQuestionAnsweringLoss):
|
class TFAlbertForQuestionAnswering(TFAlbertPreTrainedModel, TFQuestionAnsweringLoss):
|
||||||
|
|
||||||
authorized_missing_keys = [r"pooler"]
|
_keys_to_ignore_on_load_missing = [r"pooler"]
|
||||||
|
|
||||||
def __init__(self, config, *inputs, **kwargs):
|
def __init__(self, config, *inputs, **kwargs):
|
||||||
super().__init__(config, *inputs, **kwargs)
|
super().__init__(config, *inputs, **kwargs)
|
||||||
|
|||||||
@@ -946,7 +946,7 @@ class BartModel(PretrainedBartModel):
|
|||||||
)
|
)
|
||||||
class BartForConditionalGeneration(PretrainedBartModel):
|
class BartForConditionalGeneration(PretrainedBartModel):
|
||||||
base_model_prefix = "model"
|
base_model_prefix = "model"
|
||||||
authorized_missing_keys = [r"final_logits_bias", r"encoder\.version", r"decoder\.version"]
|
_keys_to_ignore_on_load_missing = [r"final_logits_bias", r"encoder\.version", r"decoder\.version"]
|
||||||
|
|
||||||
def __init__(self, config: BartConfig):
|
def __init__(self, config: BartConfig):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|||||||
@@ -1020,10 +1020,10 @@ class TFBartModel(TFPretrainedBartModel):
|
|||||||
)
|
)
|
||||||
class TFBartForConditionalGeneration(TFPretrainedBartModel):
|
class TFBartForConditionalGeneration(TFPretrainedBartModel):
|
||||||
base_model_prefix = "model"
|
base_model_prefix = "model"
|
||||||
authorized_missing_keys = [
|
_keys_to_ignore_on_load_missing = [
|
||||||
r"final_logits_bias",
|
r"final_logits_bias",
|
||||||
]
|
]
|
||||||
authorized_unexpected_keys = [
|
_keys_to_ignore_on_load_unexpected = [
|
||||||
r"model.encoder.embed_tokens.weight",
|
r"model.encoder.embed_tokens.weight",
|
||||||
r"model.decoder.embed_tokens.weight",
|
r"model.decoder.embed_tokens.weight",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -598,7 +598,7 @@ class BertPreTrainedModel(PreTrainedModel):
|
|||||||
config_class = BertConfig
|
config_class = BertConfig
|
||||||
load_tf_weights = load_tf_weights_in_bert
|
load_tf_weights = load_tf_weights_in_bert
|
||||||
base_model_prefix = "bert"
|
base_model_prefix = "bert"
|
||||||
authorized_missing_keys = [r"position_ids"]
|
_keys_to_ignore_on_load_missing = [r"position_ids"]
|
||||||
|
|
||||||
def _init_weights(self, module):
|
def _init_weights(self, module):
|
||||||
""" Initialize the weights """
|
""" Initialize the weights """
|
||||||
@@ -969,8 +969,8 @@ class BertForPreTraining(BertPreTrainedModel):
|
|||||||
)
|
)
|
||||||
class BertLMHeadModel(BertPreTrainedModel):
|
class BertLMHeadModel(BertPreTrainedModel):
|
||||||
|
|
||||||
authorized_unexpected_keys = [r"pooler"]
|
_keys_to_ignore_on_load_unexpected = [r"pooler"]
|
||||||
authorized_missing_keys = [r"position_ids", r"predictions.decoder.bias"]
|
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
@@ -1087,8 +1087,8 @@ class BertLMHeadModel(BertPreTrainedModel):
|
|||||||
@add_start_docstrings("""Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING)
|
@add_start_docstrings("""Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING)
|
||||||
class BertForMaskedLM(BertPreTrainedModel):
|
class BertForMaskedLM(BertPreTrainedModel):
|
||||||
|
|
||||||
authorized_unexpected_keys = [r"pooler"]
|
_keys_to_ignore_on_load_unexpected = [r"pooler"]
|
||||||
authorized_missing_keys = [r"position_ids", r"predictions.decoder.bias"]
|
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
@@ -1469,7 +1469,7 @@ class BertForMultipleChoice(BertPreTrainedModel):
|
|||||||
)
|
)
|
||||||
class BertForTokenClassification(BertPreTrainedModel):
|
class BertForTokenClassification(BertPreTrainedModel):
|
||||||
|
|
||||||
authorized_unexpected_keys = [r"pooler"]
|
_keys_to_ignore_on_load_unexpected = [r"pooler"]
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
@@ -1560,7 +1560,7 @@ class BertForTokenClassification(BertPreTrainedModel):
|
|||||||
)
|
)
|
||||||
class BertForQuestionAnswering(BertPreTrainedModel):
|
class BertForQuestionAnswering(BertPreTrainedModel):
|
||||||
|
|
||||||
authorized_unexpected_keys = [r"pooler"]
|
_keys_to_ignore_on_load_unexpected = [r"pooler"]
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|||||||
@@ -938,8 +938,8 @@ class TFBertForPreTraining(TFBertPreTrainedModel, TFBertPreTrainingLoss):
|
|||||||
@add_start_docstrings("""Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING)
|
@add_start_docstrings("""Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING)
|
||||||
class TFBertForMaskedLM(TFBertPreTrainedModel, TFMaskedLanguageModelingLoss):
|
class TFBertForMaskedLM(TFBertPreTrainedModel, TFMaskedLanguageModelingLoss):
|
||||||
|
|
||||||
authorized_unexpected_keys = [r"pooler"]
|
_keys_to_ignore_on_load_unexpected = [r"pooler"]
|
||||||
authorized_missing_keys = [r"pooler"]
|
_keys_to_ignore_on_load_missing = [r"pooler"]
|
||||||
|
|
||||||
def __init__(self, config, *inputs, **kwargs):
|
def __init__(self, config, *inputs, **kwargs):
|
||||||
super().__init__(config, *inputs, **kwargs)
|
super().__init__(config, *inputs, **kwargs)
|
||||||
@@ -1023,8 +1023,8 @@ class TFBertForMaskedLM(TFBertPreTrainedModel, TFMaskedLanguageModelingLoss):
|
|||||||
|
|
||||||
class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss):
|
class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss):
|
||||||
|
|
||||||
authorized_unexpected_keys = [r"pooler"]
|
_keys_to_ignore_on_load_unexpected = [r"pooler"]
|
||||||
authorized_missing_keys = [r"pooler"]
|
_keys_to_ignore_on_load_missing = [r"pooler"]
|
||||||
|
|
||||||
def __init__(self, config, *inputs, **kwargs):
|
def __init__(self, config, *inputs, **kwargs):
|
||||||
super().__init__(config, *inputs, **kwargs)
|
super().__init__(config, *inputs, **kwargs)
|
||||||
@@ -1416,8 +1416,8 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss):
|
|||||||
)
|
)
|
||||||
class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationLoss):
|
class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationLoss):
|
||||||
|
|
||||||
authorized_unexpected_keys = [r"pooler"]
|
_keys_to_ignore_on_load_unexpected = [r"pooler"]
|
||||||
authorized_missing_keys = [r"pooler"]
|
_keys_to_ignore_on_load_missing = [r"pooler"]
|
||||||
|
|
||||||
def __init__(self, config, *inputs, **kwargs):
|
def __init__(self, config, *inputs, **kwargs):
|
||||||
super().__init__(config, *inputs, **kwargs)
|
super().__init__(config, *inputs, **kwargs)
|
||||||
@@ -1502,8 +1502,8 @@ class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationL
|
|||||||
)
|
)
|
||||||
class TFBertForQuestionAnswering(TFBertPreTrainedModel, TFQuestionAnsweringLoss):
|
class TFBertForQuestionAnswering(TFBertPreTrainedModel, TFQuestionAnsweringLoss):
|
||||||
|
|
||||||
authorized_unexpected_keys = [r"pooler"]
|
_keys_to_ignore_on_load_unexpected = [r"pooler"]
|
||||||
authorized_missing_keys = [r"pooler"]
|
_keys_to_ignore_on_load_missing = [r"pooler"]
|
||||||
|
|
||||||
def __init__(self, config, *inputs, **kwargs):
|
def __init__(self, config, *inputs, **kwargs):
|
||||||
super().__init__(config, *inputs, **kwargs)
|
super().__init__(config, *inputs, **kwargs)
|
||||||
|
|||||||
@@ -173,7 +173,7 @@ class BertGenerationPreTrainedModel(PreTrainedModel):
|
|||||||
|
|
||||||
config_class = BertGenerationConfig
|
config_class = BertGenerationConfig
|
||||||
base_model_prefix = "bert"
|
base_model_prefix = "bert"
|
||||||
authorized_missing_keys = [r"position_ids"]
|
_keys_to_ignore_on_load_missing = [r"position_ids"]
|
||||||
|
|
||||||
def _init_weights(self, module):
|
def _init_weights(self, module):
|
||||||
""" Initialize the weights """
|
""" Initialize the weights """
|
||||||
|
|||||||
@@ -756,7 +756,7 @@ class DebertaPreTrainedModel(PreTrainedModel):
|
|||||||
|
|
||||||
config_class = DebertaConfig
|
config_class = DebertaConfig
|
||||||
base_model_prefix = "deberta"
|
base_model_prefix = "deberta"
|
||||||
authorized_missing_keys = ["position_ids"]
|
_keys_to_ignore_on_load_missing = ["position_ids"]
|
||||||
|
|
||||||
def _init_weights(self, module):
|
def _init_weights(self, module):
|
||||||
""" Initialize the weights """
|
""" Initialize the weights """
|
||||||
|
|||||||
@@ -279,7 +279,7 @@ class DPRPretrainedContextEncoder(PreTrainedModel):
|
|||||||
config_class = DPRConfig
|
config_class = DPRConfig
|
||||||
load_tf_weights = None
|
load_tf_weights = None
|
||||||
base_model_prefix = "ctx_encoder"
|
base_model_prefix = "ctx_encoder"
|
||||||
authorized_missing_keys = [r"position_ids"]
|
_keys_to_ignore_on_load_missing = [r"position_ids"]
|
||||||
|
|
||||||
def init_weights(self):
|
def init_weights(self):
|
||||||
self.ctx_encoder.init_weights()
|
self.ctx_encoder.init_weights()
|
||||||
@@ -294,7 +294,7 @@ class DPRPretrainedQuestionEncoder(PreTrainedModel):
|
|||||||
config_class = DPRConfig
|
config_class = DPRConfig
|
||||||
load_tf_weights = None
|
load_tf_weights = None
|
||||||
base_model_prefix = "question_encoder"
|
base_model_prefix = "question_encoder"
|
||||||
authorized_missing_keys = [r"position_ids"]
|
_keys_to_ignore_on_load_missing = [r"position_ids"]
|
||||||
|
|
||||||
def init_weights(self):
|
def init_weights(self):
|
||||||
self.question_encoder.init_weights()
|
self.question_encoder.init_weights()
|
||||||
@@ -309,7 +309,7 @@ class DPRPretrainedReader(PreTrainedModel):
|
|||||||
config_class = DPRConfig
|
config_class = DPRConfig
|
||||||
load_tf_weights = None
|
load_tf_weights = None
|
||||||
base_model_prefix = "span_predictor"
|
base_model_prefix = "span_predictor"
|
||||||
authorized_missing_keys = [r"position_ids"]
|
_keys_to_ignore_on_load_missing = [r"position_ids"]
|
||||||
|
|
||||||
def init_weights(self):
|
def init_weights(self):
|
||||||
self.span_predictor.encoder.init_weights()
|
self.span_predictor.encoder.init_weights()
|
||||||
|
|||||||
@@ -544,8 +544,8 @@ class ElectraPreTrainedModel(PreTrainedModel):
|
|||||||
config_class = ElectraConfig
|
config_class = ElectraConfig
|
||||||
load_tf_weights = load_tf_weights_in_electra
|
load_tf_weights = load_tf_weights_in_electra
|
||||||
base_model_prefix = "electra"
|
base_model_prefix = "electra"
|
||||||
authorized_missing_keys = [r"position_ids"]
|
_keys_to_ignore_on_load_missing = [r"position_ids"]
|
||||||
authorized_unexpected_keys = [r"electra\.embeddings_project\.weight", r"electra\.embeddings_project\.bias"]
|
_keys_to_ignore_on_load_unexpected = [r"electra\.embeddings_project\.weight", r"electra\.embeddings_project\.bias"]
|
||||||
|
|
||||||
# Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
|
# Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
|
||||||
def _init_weights(self, module):
|
def _init_weights(self, module):
|
||||||
|
|||||||
@@ -1005,11 +1005,11 @@ class FSMTModel(PretrainedFSMTModel):
|
|||||||
)
|
)
|
||||||
class FSMTForConditionalGeneration(PretrainedFSMTModel):
|
class FSMTForConditionalGeneration(PretrainedFSMTModel):
|
||||||
base_model_prefix = "model"
|
base_model_prefix = "model"
|
||||||
authorized_missing_keys = [
|
_keys_to_ignore_on_load_missing = [
|
||||||
"model.encoder.embed_positions.weight",
|
"model.encoder.embed_positions.weight",
|
||||||
"model.decoder.embed_positions.weight",
|
"model.decoder.embed_positions.weight",
|
||||||
]
|
]
|
||||||
keys_to_never_save = [
|
_keys_to_ignore_on_save = [
|
||||||
"model.encoder.embed_positions.weight",
|
"model.encoder.embed_positions.weight",
|
||||||
"model.decoder.embed_positions.weight",
|
"model.decoder.embed_positions.weight",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -780,7 +780,7 @@ class GPT2Model(GPT2PreTrainedModel):
|
|||||||
GPT2_START_DOCSTRING,
|
GPT2_START_DOCSTRING,
|
||||||
)
|
)
|
||||||
class GPT2LMHeadModel(GPT2PreTrainedModel):
|
class GPT2LMHeadModel(GPT2PreTrainedModel):
|
||||||
authorized_missing_keys = [r"h\.\d+\.attn\.masked_bias", r"lm_head\.weight"]
|
_keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"lm_head\.weight"]
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
@@ -1097,7 +1097,7 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
|
|||||||
GPT2_START_DOCSTRING,
|
GPT2_START_DOCSTRING,
|
||||||
)
|
)
|
||||||
class GPT2ForSequenceClassification(GPT2PreTrainedModel):
|
class GPT2ForSequenceClassification(GPT2PreTrainedModel):
|
||||||
authorized_missing_keys = [r"h\.\d+\.attn\.masked_bias", r"lm_head\.weight"]
|
_keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"lm_head\.weight"]
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|||||||
@@ -509,7 +509,7 @@ class LayoutLMPreTrainedModel(PreTrainedModel):
|
|||||||
|
|
||||||
config_class = LayoutLMConfig
|
config_class = LayoutLMConfig
|
||||||
base_model_prefix = "layoutlm"
|
base_model_prefix = "layoutlm"
|
||||||
authorized_missing_keys = [r"position_ids"]
|
_keys_to_ignore_on_load_missing = [r"position_ids"]
|
||||||
|
|
||||||
def _init_weights(self, module):
|
def _init_weights(self, module):
|
||||||
""" Initialize the weights """
|
""" Initialize the weights """
|
||||||
|
|||||||
@@ -1303,7 +1303,7 @@ class LongformerPreTrainedModel(PreTrainedModel):
|
|||||||
|
|
||||||
config_class = LongformerConfig
|
config_class = LongformerConfig
|
||||||
base_model_prefix = "longformer"
|
base_model_prefix = "longformer"
|
||||||
authorized_missing_keys = [r"position_ids"]
|
_keys_to_ignore_on_load_missing = [r"position_ids"]
|
||||||
|
|
||||||
def _init_weights(self, module):
|
def _init_weights(self, module):
|
||||||
""" Initialize the weights """
|
""" Initialize the weights """
|
||||||
@@ -1621,7 +1621,7 @@ class LongformerModel(LongformerPreTrainedModel):
|
|||||||
@add_start_docstrings("""Longformer Model with a `language modeling` head on top. """, LONGFORMER_START_DOCSTRING)
|
@add_start_docstrings("""Longformer Model with a `language modeling` head on top. """, LONGFORMER_START_DOCSTRING)
|
||||||
class LongformerForMaskedLM(LongformerPreTrainedModel):
|
class LongformerForMaskedLM(LongformerPreTrainedModel):
|
||||||
|
|
||||||
authorized_unexpected_keys = [r"pooler"]
|
_keys_to_ignore_on_load_unexpected = [r"pooler"]
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
@@ -1718,7 +1718,7 @@ class LongformerForMaskedLM(LongformerPreTrainedModel):
|
|||||||
)
|
)
|
||||||
class LongformerForSequenceClassification(LongformerPreTrainedModel):
|
class LongformerForSequenceClassification(LongformerPreTrainedModel):
|
||||||
|
|
||||||
authorized_unexpected_keys = [r"pooler"]
|
_keys_to_ignore_on_load_unexpected = [r"pooler"]
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
@@ -1827,7 +1827,7 @@ class LongformerClassificationHead(nn.Module):
|
|||||||
)
|
)
|
||||||
class LongformerForQuestionAnswering(LongformerPreTrainedModel):
|
class LongformerForQuestionAnswering(LongformerPreTrainedModel):
|
||||||
|
|
||||||
authorized_unexpected_keys = [r"pooler"]
|
_keys_to_ignore_on_load_unexpected = [r"pooler"]
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
@@ -1961,7 +1961,7 @@ class LongformerForQuestionAnswering(LongformerPreTrainedModel):
|
|||||||
)
|
)
|
||||||
class LongformerForTokenClassification(LongformerPreTrainedModel):
|
class LongformerForTokenClassification(LongformerPreTrainedModel):
|
||||||
|
|
||||||
authorized_unexpected_keys = [r"pooler"]
|
_keys_to_ignore_on_load_unexpected = [r"pooler"]
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|||||||
@@ -1961,7 +1961,7 @@ class TFLongformerModel(TFLongformerPreTrainedModel):
|
|||||||
)
|
)
|
||||||
class TFLongformerForMaskedLM(TFLongformerPreTrainedModel, TFMaskedLanguageModelingLoss):
|
class TFLongformerForMaskedLM(TFLongformerPreTrainedModel, TFMaskedLanguageModelingLoss):
|
||||||
|
|
||||||
authorized_missing_keys = [r"pooler"]
|
_keys_to_ignore_on_load_missing = [r"pooler"]
|
||||||
|
|
||||||
def __init__(self, config, *inputs, **kwargs):
|
def __init__(self, config, *inputs, **kwargs):
|
||||||
super().__init__(config, *inputs, **kwargs)
|
super().__init__(config, *inputs, **kwargs)
|
||||||
@@ -2048,7 +2048,7 @@ class TFLongformerForMaskedLM(TFLongformerPreTrainedModel, TFMaskedLanguageModel
|
|||||||
)
|
)
|
||||||
class TFLongformerForQuestionAnswering(TFLongformerPreTrainedModel, TFQuestionAnsweringLoss):
|
class TFLongformerForQuestionAnswering(TFLongformerPreTrainedModel, TFQuestionAnsweringLoss):
|
||||||
|
|
||||||
authorized_missing_keys = [r"pooler"]
|
_keys_to_ignore_on_load_missing = [r"pooler"]
|
||||||
|
|
||||||
def __init__(self, config, *inputs, **kwargs):
|
def __init__(self, config, *inputs, **kwargs):
|
||||||
super().__init__(config, *inputs, **kwargs)
|
super().__init__(config, *inputs, **kwargs)
|
||||||
@@ -2199,7 +2199,7 @@ class TFLongformerClassificationHead(tf.keras.layers.Layer):
|
|||||||
)
|
)
|
||||||
class TFLongformerForSequenceClassification(TFLongformerPreTrainedModel, TFSequenceClassificationLoss):
|
class TFLongformerForSequenceClassification(TFLongformerPreTrainedModel, TFSequenceClassificationLoss):
|
||||||
|
|
||||||
authorized_missing_keys = [r"pooler"]
|
_keys_to_ignore_on_load_missing = [r"pooler"]
|
||||||
|
|
||||||
def __init__(self, config, *inputs, **kwargs):
|
def __init__(self, config, *inputs, **kwargs):
|
||||||
super().__init__(config, *inputs, **kwargs)
|
super().__init__(config, *inputs, **kwargs)
|
||||||
@@ -2443,7 +2443,7 @@ class TFLongformerForMultipleChoice(TFLongformerPreTrainedModel, TFMultipleChoic
|
|||||||
)
|
)
|
||||||
class TFLongformerForTokenClassification(TFLongformerPreTrainedModel, TFTokenClassificationLoss):
|
class TFLongformerForTokenClassification(TFLongformerPreTrainedModel, TFTokenClassificationLoss):
|
||||||
|
|
||||||
authorized_missing_keys = [r"pooler"]
|
_keys_to_ignore_on_load_missing = [r"pooler"]
|
||||||
|
|
||||||
def __init__(self, config, *inputs, **kwargs):
|
def __init__(self, config, *inputs, **kwargs):
|
||||||
super().__init__(config, *inputs, **kwargs)
|
super().__init__(config, *inputs, **kwargs)
|
||||||
|
|||||||
@@ -47,11 +47,11 @@ class MarianMTModel(BartForConditionalGeneration):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
config_class = MarianConfig
|
config_class = MarianConfig
|
||||||
authorized_missing_keys = [
|
_keys_to_ignore_on_load_missing = [
|
||||||
"model.encoder.embed_positions.weight",
|
"model.encoder.embed_positions.weight",
|
||||||
"model.decoder.embed_positions.weight",
|
"model.decoder.embed_positions.weight",
|
||||||
]
|
]
|
||||||
keys_to_never_save = [
|
_keys_to_ignore_on_save = [
|
||||||
"model.encoder.embed_positions.weight",
|
"model.encoder.embed_positions.weight",
|
||||||
"model.decoder.embed_positions.weight",
|
"model.decoder.embed_positions.weight",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ logger = logging.get_logger(__name__)
|
|||||||
|
|
||||||
@add_start_docstrings("Marian model for machine translation", START_DOCSTRING)
|
@add_start_docstrings("Marian model for machine translation", START_DOCSTRING)
|
||||||
class TFMarianMTModel(TFBartForConditionalGeneration):
|
class TFMarianMTModel(TFBartForConditionalGeneration):
|
||||||
authorized_missing_keys = [
|
_keys_to_ignore_on_load_missing = [
|
||||||
r"model.encoder.embed_positions.weight",
|
r"model.encoder.embed_positions.weight",
|
||||||
r"model.decoder.embed_positions.weight",
|
r"model.decoder.embed_positions.weight",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -29,11 +29,11 @@ class MBartForConditionalGeneration(BartForConditionalGeneration):
|
|||||||
"""
|
"""
|
||||||
model_type = "mbart"
|
model_type = "mbart"
|
||||||
config_class = MBartConfig
|
config_class = MBartConfig
|
||||||
authorized_missing_keys = [
|
_keys_to_ignore_on_load_missing = [
|
||||||
"model.encoder.embed_positions.weight",
|
"model.encoder.embed_positions.weight",
|
||||||
"model.decoder.embed_positions.weight",
|
"model.decoder.embed_positions.weight",
|
||||||
]
|
]
|
||||||
keys_to_never_save = [
|
_keys_to_ignore_on_save = [
|
||||||
"model.encoder.embed_positions.weight",
|
"model.encoder.embed_positions.weight",
|
||||||
"model.decoder.embed_positions.weight",
|
"model.decoder.embed_positions.weight",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -677,7 +677,7 @@ class MobileBertPreTrainedModel(PreTrainedModel):
|
|||||||
pretrained_model_archive_map = MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST
|
pretrained_model_archive_map = MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST
|
||||||
load_tf_weights = load_tf_weights_in_mobilebert
|
load_tf_weights = load_tf_weights_in_mobilebert
|
||||||
base_model_prefix = "mobilebert"
|
base_model_prefix = "mobilebert"
|
||||||
authorized_missing_keys = [r"position_ids"]
|
_keys_to_ignore_on_load_missing = [r"position_ids"]
|
||||||
|
|
||||||
def _init_weights(self, module):
|
def _init_weights(self, module):
|
||||||
""" Initialize the weights """
|
""" Initialize the weights """
|
||||||
@@ -1054,7 +1054,7 @@ class MobileBertForPreTraining(MobileBertPreTrainedModel):
|
|||||||
@add_start_docstrings("""MobileBert Model with a `language modeling` head on top. """, MOBILEBERT_START_DOCSTRING)
|
@add_start_docstrings("""MobileBert Model with a `language modeling` head on top. """, MOBILEBERT_START_DOCSTRING)
|
||||||
class MobileBertForMaskedLM(MobileBertPreTrainedModel):
|
class MobileBertForMaskedLM(MobileBertPreTrainedModel):
|
||||||
|
|
||||||
authorized_unexpected_keys = [r"pooler"]
|
_keys_to_ignore_on_load_unexpected = [r"pooler"]
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
@@ -1350,7 +1350,7 @@ class MobileBertForSequenceClassification(MobileBertPreTrainedModel):
|
|||||||
)
|
)
|
||||||
class MobileBertForQuestionAnswering(MobileBertPreTrainedModel):
|
class MobileBertForQuestionAnswering(MobileBertPreTrainedModel):
|
||||||
|
|
||||||
authorized_unexpected_keys = [r"pooler"]
|
_keys_to_ignore_on_load_unexpected = [r"pooler"]
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
@@ -1545,7 +1545,7 @@ class MobileBertForMultipleChoice(MobileBertPreTrainedModel):
|
|||||||
)
|
)
|
||||||
class MobileBertForTokenClassification(MobileBertPreTrainedModel):
|
class MobileBertForTokenClassification(MobileBertPreTrainedModel):
|
||||||
|
|
||||||
authorized_unexpected_keys = [r"pooler"]
|
_keys_to_ignore_on_load_unexpected = [r"pooler"]
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|||||||
@@ -1030,7 +1030,7 @@ class TFMobileBertForPreTraining(TFMobileBertPreTrainedModel):
|
|||||||
@add_start_docstrings("""MobileBert Model with a `language modeling` head on top. """, MOBILEBERT_START_DOCSTRING)
|
@add_start_docstrings("""MobileBert Model with a `language modeling` head on top. """, MOBILEBERT_START_DOCSTRING)
|
||||||
class TFMobileBertForMaskedLM(TFMobileBertPreTrainedModel, TFMaskedLanguageModelingLoss):
|
class TFMobileBertForMaskedLM(TFMobileBertPreTrainedModel, TFMaskedLanguageModelingLoss):
|
||||||
|
|
||||||
authorized_missing_keys = [r"pooler"]
|
_keys_to_ignore_on_load_missing = [r"pooler"]
|
||||||
|
|
||||||
def __init__(self, config, *inputs, **kwargs):
|
def __init__(self, config, *inputs, **kwargs):
|
||||||
super().__init__(config, *inputs, **kwargs)
|
super().__init__(config, *inputs, **kwargs)
|
||||||
@@ -1297,7 +1297,7 @@ class TFMobileBertForSequenceClassification(TFMobileBertPreTrainedModel, TFSeque
|
|||||||
)
|
)
|
||||||
class TFMobileBertForQuestionAnswering(TFMobileBertPreTrainedModel, TFQuestionAnsweringLoss):
|
class TFMobileBertForQuestionAnswering(TFMobileBertPreTrainedModel, TFQuestionAnsweringLoss):
|
||||||
|
|
||||||
authorized_missing_keys = [r"pooler"]
|
_keys_to_ignore_on_load_missing = [r"pooler"]
|
||||||
|
|
||||||
def __init__(self, config, *inputs, **kwargs):
|
def __init__(self, config, *inputs, **kwargs):
|
||||||
super().__init__(config, *inputs, **kwargs)
|
super().__init__(config, *inputs, **kwargs)
|
||||||
@@ -1529,7 +1529,7 @@ class TFMobileBertForMultipleChoice(TFMobileBertPreTrainedModel, TFMultipleChoic
|
|||||||
)
|
)
|
||||||
class TFMobileBertForTokenClassification(TFMobileBertPreTrainedModel, TFTokenClassificationLoss):
|
class TFMobileBertForTokenClassification(TFMobileBertPreTrainedModel, TFTokenClassificationLoss):
|
||||||
|
|
||||||
authorized_missing_keys = [r"pooler"]
|
_keys_to_ignore_on_load_missing = [r"pooler"]
|
||||||
|
|
||||||
def __init__(self, config, *inputs, **kwargs):
|
def __init__(self, config, *inputs, **kwargs):
|
||||||
super().__init__(config, *inputs, **kwargs)
|
super().__init__(config, *inputs, **kwargs)
|
||||||
|
|||||||
@@ -42,12 +42,12 @@ class MT5Model(T5Model):
|
|||||||
"""
|
"""
|
||||||
model_type = "mt5"
|
model_type = "mt5"
|
||||||
config_class = MT5Config
|
config_class = MT5Config
|
||||||
authorized_missing_keys = [
|
_keys_to_ignore_on_load_missing = [
|
||||||
r"encoder\.embed_tokens\.weight",
|
r"encoder\.embed_tokens\.weight",
|
||||||
r"decoder\.embed_tokens\.weight",
|
r"decoder\.embed_tokens\.weight",
|
||||||
r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight",
|
r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight",
|
||||||
]
|
]
|
||||||
keys_to_never_save = [
|
_keys_to_ignore_on_save = [
|
||||||
r"encoder\.embed_tokens\.weight",
|
r"encoder\.embed_tokens\.weight",
|
||||||
r"decoder\.embed_tokens\.weight",
|
r"decoder\.embed_tokens\.weight",
|
||||||
]
|
]
|
||||||
@@ -71,13 +71,13 @@ class MT5ForConditionalGeneration(T5ForConditionalGeneration):
|
|||||||
|
|
||||||
model_type = "mt5"
|
model_type = "mt5"
|
||||||
config_class = MT5Config
|
config_class = MT5Config
|
||||||
authorized_missing_keys = [
|
_keys_to_ignore_on_load_missing = [
|
||||||
r"encoder\.embed_tokens\.weight",
|
r"encoder\.embed_tokens\.weight",
|
||||||
r"decoder\.embed_tokens\.weight",
|
r"decoder\.embed_tokens\.weight",
|
||||||
r"lm_head\.weight",
|
r"lm_head\.weight",
|
||||||
r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight",
|
r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight",
|
||||||
]
|
]
|
||||||
keys_to_never_save = [
|
_keys_to_ignore_on_save = [
|
||||||
r"encoder\.embed_tokens\.weight",
|
r"encoder\.embed_tokens\.weight",
|
||||||
r"decoder\.embed_tokens\.weight",
|
r"decoder\.embed_tokens\.weight",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -279,7 +279,7 @@ class OpenAIGPTPreTrainedModel(PreTrainedModel):
|
|||||||
config_class = OpenAIGPTConfig
|
config_class = OpenAIGPTConfig
|
||||||
load_tf_weights = load_tf_weights_in_openai_gpt
|
load_tf_weights = load_tf_weights_in_openai_gpt
|
||||||
base_model_prefix = "transformer"
|
base_model_prefix = "transformer"
|
||||||
authorized_missing_keys = [r"position_ids"]
|
_keys_to_ignore_on_load_missing = [r"position_ids"]
|
||||||
|
|
||||||
def _init_weights(self, module):
|
def _init_weights(self, module):
|
||||||
"""Initialize the weights."""
|
"""Initialize the weights."""
|
||||||
|
|||||||
@@ -46,14 +46,14 @@ class PegasusForConditionalGeneration(BartForConditionalGeneration):
|
|||||||
"""
|
"""
|
||||||
# All the code is in src/transformers/models/bart/modeling_bart.py
|
# All the code is in src/transformers/models/bart/modeling_bart.py
|
||||||
config_class = PegasusConfig
|
config_class = PegasusConfig
|
||||||
authorized_missing_keys = [
|
_keys_to_ignore_on_load_missing = [
|
||||||
r"final_logits_bias",
|
r"final_logits_bias",
|
||||||
r"encoder\.version",
|
r"encoder\.version",
|
||||||
r"decoder\.version",
|
r"decoder\.version",
|
||||||
"model.encoder.embed_positions",
|
"model.encoder.embed_positions",
|
||||||
"model.decoder.embed_positions",
|
"model.decoder.embed_positions",
|
||||||
]
|
]
|
||||||
keys_to_never_save = [
|
_keys_to_ignore_on_save = [
|
||||||
"model.encoder.embed_positions.weight",
|
"model.encoder.embed_positions.weight",
|
||||||
"model.decoder.embed_positions.weight",
|
"model.decoder.embed_positions.weight",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ logger = logging.get_logger(__name__)
|
|||||||
|
|
||||||
@add_start_docstrings("Pegasus model for summarization", START_DOCSTRING)
|
@add_start_docstrings("Pegasus model for summarization", START_DOCSTRING)
|
||||||
class TFPegasusForConditionalGeneration(TFBartForConditionalGeneration):
|
class TFPegasusForConditionalGeneration(TFBartForConditionalGeneration):
|
||||||
authorized_missing_keys = [
|
_keys_to_ignore_on_load_missing = [
|
||||||
r"final_logits_bias",
|
r"final_logits_bias",
|
||||||
r"model.encoder.embed_positions.weight",
|
r"model.encoder.embed_positions.weight",
|
||||||
r"model.decoder.embed_positions.weight",
|
r"model.decoder.embed_positions.weight",
|
||||||
|
|||||||
@@ -216,7 +216,7 @@ class RagPreTrainedModel(PreTrainedModel):
|
|||||||
"""
|
"""
|
||||||
config_class = RagConfig
|
config_class = RagConfig
|
||||||
base_model_prefix = "rag"
|
base_model_prefix = "rag"
|
||||||
authorized_missing_keys = [r"position_ids"]
|
_keys_to_ignore_on_load_missing = [r"position_ids"]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained_question_encoder_generator(
|
def from_pretrained_question_encoder_generator(
|
||||||
|
|||||||
@@ -576,7 +576,7 @@ class RobertaModel(RobertaPreTrainedModel):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
authorized_missing_keys = [r"position_ids"]
|
_keys_to_ignore_on_load_missing = [r"position_ids"]
|
||||||
|
|
||||||
# Copied from transformers.models.bert.modeling_bert.BertModel.__init__ with Bert->Roberta
|
# Copied from transformers.models.bert.modeling_bert.BertModel.__init__ with Bert->Roberta
|
||||||
def __init__(self, config, add_pooling_layer=True):
|
def __init__(self, config, add_pooling_layer=True):
|
||||||
@@ -711,8 +711,8 @@ class RobertaModel(RobertaPreTrainedModel):
|
|||||||
"""RoBERTa Model with a `language modeling` head on top for CLM fine-tuning. """, ROBERTA_START_DOCSTRING
|
"""RoBERTa Model with a `language modeling` head on top for CLM fine-tuning. """, ROBERTA_START_DOCSTRING
|
||||||
)
|
)
|
||||||
class RobertaForCausalLM(RobertaPreTrainedModel):
|
class RobertaForCausalLM(RobertaPreTrainedModel):
|
||||||
authorized_missing_keys = [r"position_ids", r"predictions.decoder.bias"]
|
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
|
||||||
authorized_unexpected_keys = [r"pooler"]
|
_keys_to_ignore_on_load_unexpected = [r"pooler"]
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
@@ -829,8 +829,8 @@ class RobertaForCausalLM(RobertaPreTrainedModel):
|
|||||||
|
|
||||||
@add_start_docstrings("""RoBERTa Model with a `language modeling` head on top. """, ROBERTA_START_DOCSTRING)
|
@add_start_docstrings("""RoBERTa Model with a `language modeling` head on top. """, ROBERTA_START_DOCSTRING)
|
||||||
class RobertaForMaskedLM(RobertaPreTrainedModel):
|
class RobertaForMaskedLM(RobertaPreTrainedModel):
|
||||||
authorized_missing_keys = [r"position_ids", r"predictions.decoder.bias"]
|
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
|
||||||
authorized_unexpected_keys = [r"pooler"]
|
_keys_to_ignore_on_load_unexpected = [r"pooler"]
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
@@ -948,7 +948,7 @@ class RobertaLMHead(nn.Module):
|
|||||||
ROBERTA_START_DOCSTRING,
|
ROBERTA_START_DOCSTRING,
|
||||||
)
|
)
|
||||||
class RobertaForSequenceClassification(RobertaPreTrainedModel):
|
class RobertaForSequenceClassification(RobertaPreTrainedModel):
|
||||||
authorized_missing_keys = [r"position_ids"]
|
_keys_to_ignore_on_load_missing = [r"position_ids"]
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
@@ -1031,7 +1031,7 @@ class RobertaForSequenceClassification(RobertaPreTrainedModel):
|
|||||||
ROBERTA_START_DOCSTRING,
|
ROBERTA_START_DOCSTRING,
|
||||||
)
|
)
|
||||||
class RobertaForMultipleChoice(RobertaPreTrainedModel):
|
class RobertaForMultipleChoice(RobertaPreTrainedModel):
|
||||||
authorized_missing_keys = [r"position_ids"]
|
_keys_to_ignore_on_load_missing = [r"position_ids"]
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
@@ -1123,8 +1123,8 @@ class RobertaForMultipleChoice(RobertaPreTrainedModel):
|
|||||||
ROBERTA_START_DOCSTRING,
|
ROBERTA_START_DOCSTRING,
|
||||||
)
|
)
|
||||||
class RobertaForTokenClassification(RobertaPreTrainedModel):
|
class RobertaForTokenClassification(RobertaPreTrainedModel):
|
||||||
authorized_unexpected_keys = [r"pooler"]
|
_keys_to_ignore_on_load_unexpected = [r"pooler"]
|
||||||
authorized_missing_keys = [r"position_ids"]
|
_keys_to_ignore_on_load_missing = [r"position_ids"]
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
@@ -1233,8 +1233,8 @@ class RobertaClassificationHead(nn.Module):
|
|||||||
ROBERTA_START_DOCSTRING,
|
ROBERTA_START_DOCSTRING,
|
||||||
)
|
)
|
||||||
class RobertaForQuestionAnswering(RobertaPreTrainedModel):
|
class RobertaForQuestionAnswering(RobertaPreTrainedModel):
|
||||||
authorized_unexpected_keys = [r"pooler"]
|
_keys_to_ignore_on_load_unexpected = [r"pooler"]
|
||||||
authorized_missing_keys = [r"position_ids"]
|
_keys_to_ignore_on_load_missing = [r"position_ids"]
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|||||||
@@ -765,7 +765,7 @@ class TFRobertaLMHead(tf.keras.layers.Layer):
|
|||||||
@add_start_docstrings("""RoBERTa Model with a `language modeling` head on top. """, ROBERTA_START_DOCSTRING)
|
@add_start_docstrings("""RoBERTa Model with a `language modeling` head on top. """, ROBERTA_START_DOCSTRING)
|
||||||
class TFRobertaForMaskedLM(TFRobertaPreTrainedModel, TFMaskedLanguageModelingLoss):
|
class TFRobertaForMaskedLM(TFRobertaPreTrainedModel, TFMaskedLanguageModelingLoss):
|
||||||
|
|
||||||
authorized_missing_keys = [r"pooler"]
|
_keys_to_ignore_on_load_missing = [r"pooler"]
|
||||||
|
|
||||||
def __init__(self, config, *inputs, **kwargs):
|
def __init__(self, config, *inputs, **kwargs):
|
||||||
super().__init__(config, *inputs, **kwargs)
|
super().__init__(config, *inputs, **kwargs)
|
||||||
@@ -877,7 +877,7 @@ class TFRobertaClassificationHead(tf.keras.layers.Layer):
|
|||||||
)
|
)
|
||||||
class TFRobertaForSequenceClassification(TFRobertaPreTrainedModel, TFSequenceClassificationLoss):
|
class TFRobertaForSequenceClassification(TFRobertaPreTrainedModel, TFSequenceClassificationLoss):
|
||||||
|
|
||||||
authorized_missing_keys = [r"pooler"]
|
_keys_to_ignore_on_load_missing = [r"pooler"]
|
||||||
|
|
||||||
def __init__(self, config, *inputs, **kwargs):
|
def __init__(self, config, *inputs, **kwargs):
|
||||||
super().__init__(config, *inputs, **kwargs)
|
super().__init__(config, *inputs, **kwargs)
|
||||||
@@ -1084,7 +1084,7 @@ class TFRobertaForMultipleChoice(TFRobertaPreTrainedModel, TFMultipleChoiceLoss)
|
|||||||
)
|
)
|
||||||
class TFRobertaForTokenClassification(TFRobertaPreTrainedModel, TFTokenClassificationLoss):
|
class TFRobertaForTokenClassification(TFRobertaPreTrainedModel, TFTokenClassificationLoss):
|
||||||
|
|
||||||
authorized_missing_keys = [r"pooler"]
|
_keys_to_ignore_on_load_missing = [r"pooler"]
|
||||||
|
|
||||||
def __init__(self, config, *inputs, **kwargs):
|
def __init__(self, config, *inputs, **kwargs):
|
||||||
super().__init__(config, *inputs, **kwargs)
|
super().__init__(config, *inputs, **kwargs)
|
||||||
@@ -1171,7 +1171,7 @@ class TFRobertaForTokenClassification(TFRobertaPreTrainedModel, TFTokenClassific
|
|||||||
)
|
)
|
||||||
class TFRobertaForQuestionAnswering(TFRobertaPreTrainedModel, TFQuestionAnsweringLoss):
|
class TFRobertaForQuestionAnswering(TFRobertaPreTrainedModel, TFQuestionAnsweringLoss):
|
||||||
|
|
||||||
authorized_missing_keys = [r"pooler"]
|
_keys_to_ignore_on_load_missing = [r"pooler"]
|
||||||
|
|
||||||
def __init__(self, config, *inputs, **kwargs):
|
def __init__(self, config, *inputs, **kwargs):
|
||||||
super().__init__(config, *inputs, **kwargs)
|
super().__init__(config, *inputs, **kwargs)
|
||||||
|
|||||||
@@ -428,7 +428,7 @@ class SqueezeBertPreTrainedModel(PreTrainedModel):
|
|||||||
|
|
||||||
config_class = SqueezeBertConfig
|
config_class = SqueezeBertConfig
|
||||||
base_model_prefix = "transformer"
|
base_model_prefix = "transformer"
|
||||||
authorized_missing_keys = [r"position_ids"]
|
_keys_to_ignore_on_load_missing = [r"position_ids"]
|
||||||
|
|
||||||
def _init_weights(self, module):
|
def _init_weights(self, module):
|
||||||
""" Initialize the weights """
|
""" Initialize the weights """
|
||||||
@@ -642,7 +642,7 @@ class SqueezeBertModel(SqueezeBertPreTrainedModel):
|
|||||||
@add_start_docstrings("""SqueezeBERT Model with a `language modeling` head on top. """, SQUEEZEBERT_START_DOCSTRING)
|
@add_start_docstrings("""SqueezeBERT Model with a `language modeling` head on top. """, SQUEEZEBERT_START_DOCSTRING)
|
||||||
class SqueezeBertForMaskedLM(SqueezeBertPreTrainedModel):
|
class SqueezeBertForMaskedLM(SqueezeBertPreTrainedModel):
|
||||||
|
|
||||||
authorized_missing_keys = [r"predictions.decoder.bias"]
|
_keys_to_ignore_on_load_missing = [r"predictions.decoder.bias"]
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|||||||
@@ -1086,7 +1086,7 @@ T5_INPUTS_DOCSTRING = r"""
|
|||||||
T5_START_DOCSTRING,
|
T5_START_DOCSTRING,
|
||||||
)
|
)
|
||||||
class T5Model(T5PreTrainedModel):
|
class T5Model(T5PreTrainedModel):
|
||||||
authorized_missing_keys = [
|
_keys_to_ignore_on_load_missing = [
|
||||||
r"encoder\.embed_tokens\.weight",
|
r"encoder\.embed_tokens\.weight",
|
||||||
r"decoder\.embed_tokens\.weight",
|
r"decoder\.embed_tokens\.weight",
|
||||||
r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight",
|
r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight",
|
||||||
@@ -1258,7 +1258,7 @@ class T5Model(T5PreTrainedModel):
|
|||||||
|
|
||||||
@add_start_docstrings("""T5 Model with a `language modeling` head on top. """, T5_START_DOCSTRING)
|
@add_start_docstrings("""T5 Model with a `language modeling` head on top. """, T5_START_DOCSTRING)
|
||||||
class T5ForConditionalGeneration(T5PreTrainedModel):
|
class T5ForConditionalGeneration(T5PreTrainedModel):
|
||||||
authorized_missing_keys = [
|
_keys_to_ignore_on_load_missing = [
|
||||||
r"encoder\.embed_tokens\.weight",
|
r"encoder\.embed_tokens\.weight",
|
||||||
r"decoder\.embed_tokens\.weight",
|
r"decoder\.embed_tokens\.weight",
|
||||||
r"lm_head\.weight",
|
r"lm_head\.weight",
|
||||||
|
|||||||
@@ -399,7 +399,7 @@ XLM_INPUTS_DOCSTRING = r"""
|
|||||||
XLM_START_DOCSTRING,
|
XLM_START_DOCSTRING,
|
||||||
)
|
)
|
||||||
class XLMModel(XLMPreTrainedModel):
|
class XLMModel(XLMPreTrainedModel):
|
||||||
authorized_missing_keys = [r"position_ids"]
|
_keys_to_ignore_on_load_missing = [r"position_ids"]
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|||||||
@@ -540,7 +540,7 @@ class {{cookiecutter.camelcase_modelname}}PreTrainedModel(PreTrainedModel):
|
|||||||
config_class = {{cookiecutter.camelcase_modelname}}Config
|
config_class = {{cookiecutter.camelcase_modelname}}Config
|
||||||
load_tf_weights = load_tf_weights_in_{{cookiecutter.lowercase_modelname}}
|
load_tf_weights = load_tf_weights_in_{{cookiecutter.lowercase_modelname}}
|
||||||
base_model_prefix = "{{cookiecutter.lowercase_modelname}}"
|
base_model_prefix = "{{cookiecutter.lowercase_modelname}}"
|
||||||
authorized_missing_keys = [r"position_ids"]
|
_keys_to_ignore_on_load_missing = [r"position_ids"]
|
||||||
|
|
||||||
def _init_weights(self, module):
|
def _init_weights(self, module):
|
||||||
""" Initialize the weights """
|
""" Initialize the weights """
|
||||||
|
|||||||
@@ -135,17 +135,17 @@ class ModelTesterMixin:
|
|||||||
max_diff = np.amax(np.abs(out_1 - out_2))
|
max_diff = np.amax(np.abs(out_1 - out_2))
|
||||||
self.assertLessEqual(max_diff, 1e-5)
|
self.assertLessEqual(max_diff, 1e-5)
|
||||||
|
|
||||||
def test_save_load_keys_to_never_save(self):
|
def test_save_load__keys_to_ignore_on_save(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
model = model_class(config)
|
model = model_class(config)
|
||||||
keys_to_never_save = getattr(model, "keys_to_never_save", None)
|
_keys_to_ignore_on_save = getattr(model, "_keys_to_ignore_on_save", None)
|
||||||
if keys_to_never_save is None:
|
if _keys_to_ignore_on_save is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# check the keys are in the original state_dict
|
# check the keys are in the original state_dict
|
||||||
for k in keys_to_never_save:
|
for k in _keys_to_ignore_on_save:
|
||||||
self.assertIn(k, model.state_dict())
|
self.assertIn(k, model.state_dict())
|
||||||
|
|
||||||
# check that certain keys didn't get saved with the model
|
# check that certain keys didn't get saved with the model
|
||||||
@@ -153,7 +153,7 @@ class ModelTesterMixin:
|
|||||||
model.save_pretrained(tmpdirname)
|
model.save_pretrained(tmpdirname)
|
||||||
output_model_file = os.path.join(tmpdirname, WEIGHTS_NAME)
|
output_model_file = os.path.join(tmpdirname, WEIGHTS_NAME)
|
||||||
state_dict_saved = torch.load(output_model_file)
|
state_dict_saved = torch.load(output_model_file)
|
||||||
for k in keys_to_never_save:
|
for k in _keys_to_ignore_on_save:
|
||||||
self.assertNotIn(k, state_dict_saved)
|
self.assertNotIn(k, state_dict_saved)
|
||||||
|
|
||||||
def test_initialization(self):
|
def test_initialization(self):
|
||||||
|
|||||||
@@ -60,7 +60,7 @@ class ModelTester:
|
|||||||
class SelectiveCommonTest(unittest.TestCase):
|
class SelectiveCommonTest(unittest.TestCase):
|
||||||
all_model_classes = (MarianMTModel,) if is_torch_available() else ()
|
all_model_classes = (MarianMTModel,) if is_torch_available() else ()
|
||||||
|
|
||||||
test_save_load_keys_to_never_save = ModelTesterMixin.test_save_load_keys_to_never_save
|
test_save_load__keys_to_ignore_on_save = ModelTesterMixin.test_save_load__keys_to_ignore_on_save
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.model_tester = ModelTester(self)
|
self.model_tester = ModelTester(self)
|
||||||
|
|||||||
@@ -47,7 +47,7 @@ class ModelTester:
|
|||||||
class SelectiveCommonTest(unittest.TestCase):
|
class SelectiveCommonTest(unittest.TestCase):
|
||||||
all_model_classes = (MBartForConditionalGeneration,) if is_torch_available() else ()
|
all_model_classes = (MBartForConditionalGeneration,) if is_torch_available() else ()
|
||||||
|
|
||||||
test_save_load_keys_to_never_save = ModelTesterMixin.test_save_load_keys_to_never_save
|
test_save_load__keys_to_ignore_on_save = ModelTesterMixin.test_save_load__keys_to_ignore_on_save
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.model_tester = ModelTester(self)
|
self.model_tester = ModelTester(self)
|
||||||
|
|||||||
@@ -43,7 +43,7 @@ class ModelTester:
|
|||||||
class SelectiveCommonTest(unittest.TestCase):
|
class SelectiveCommonTest(unittest.TestCase):
|
||||||
all_model_classes = (PegasusForConditionalGeneration,) if is_torch_available() else ()
|
all_model_classes = (PegasusForConditionalGeneration,) if is_torch_available() else ()
|
||||||
|
|
||||||
test_save_load_keys_to_never_save = ModelTesterMixin.test_save_load_keys_to_never_save
|
test_save_load__keys_to_ignore_on_save = ModelTesterMixin.test_save_load__keys_to_ignore_on_save
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.model_tester = ModelTester(self)
|
self.model_tester = ModelTester(self)
|
||||||
|
|||||||
Reference in New Issue
Block a user