consistent ignore keys + make private (#8737)
* consistent ignore keys + make private * style * renames: authorized_missing_keys => _keys_to_ignore_on_load_missing; authorized_unexpected_keys => _keys_to_ignore_on_load_unexpected * move public doc of private attributes to private comment
This commit is contained in:
@@ -404,17 +404,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
||||
|
||||
- **base_model_prefix** (:obj:`str`) -- A string indicating the attribute associated to the base model in
|
||||
derived classes of the same architecture adding modules on top of the base model.
|
||||
- **authorized_missing_keys** (:obj:`Optional[List[str]]`) -- A list of re pattern of tensor names to ignore
|
||||
when loading the model (and avoid unnecessary warnings).
|
||||
- **keys_to_never_save** (:obj:`Optional[List[str]]`) -- A list of tensor names to ignore when saving the
|
||||
model (useful for keys that aren't trained, but which are deterministic)
|
||||
|
||||
"""
|
||||
config_class = None
|
||||
base_model_prefix = ""
|
||||
authorized_missing_keys = None
|
||||
authorized_unexpected_keys = None
|
||||
keys_to_never_save = None
|
||||
# a list of regex patterns of tensor names to ignore from the model when loading the model weights
|
||||
# (and avoid unnecessary warnings).
|
||||
_keys_to_ignore_on_load_missing = None
|
||||
# a list of regex patterns of tensor names to ignore from the weights when loading the model weights
|
||||
# (and avoid unnecessary warnings).
|
||||
_keys_to_ignore_on_load_unexpected = None
|
||||
# a list of tensor names to ignore when saving the model (useful for keys that aren't
|
||||
# trained, but which are deterministic)
|
||||
_keys_to_ignore_on_save = None
|
||||
|
||||
@property
|
||||
def dummy_inputs(self) -> Dict[str, torch.Tensor]:
|
||||
@@ -719,8 +720,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
||||
state_dict = model_to_save.state_dict()
|
||||
|
||||
# Handle the case where some state_dict keys shouldn't be saved
|
||||
if self.keys_to_never_save is not None:
|
||||
state_dict = {k: v for k, v in state_dict.items() if k not in self.keys_to_never_save}
|
||||
if self._keys_to_ignore_on_save is not None:
|
||||
state_dict = {k: v for k, v in state_dict.items() if k not in self._keys_to_ignore_on_save}
|
||||
|
||||
# If we save using the predefined names, we can load using `from_pretrained`
|
||||
output_model_file = os.path.join(save_directory, WEIGHTS_NAME)
|
||||
@@ -1034,12 +1035,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
||||
|
||||
# Some models may have keys that are not in the state by design, removing them before needlessly warning
|
||||
# the user.
|
||||
if cls.authorized_missing_keys is not None:
|
||||
for pat in cls.authorized_missing_keys:
|
||||
if cls._keys_to_ignore_on_load_missing is not None:
|
||||
for pat in cls._keys_to_ignore_on_load_missing:
|
||||
missing_keys = [k for k in missing_keys if re.search(pat, k) is None]
|
||||
|
||||
if cls.authorized_unexpected_keys is not None:
|
||||
for pat in cls.authorized_unexpected_keys:
|
||||
if cls._keys_to_ignore_on_load_unexpected is not None:
|
||||
for pat in cls._keys_to_ignore_on_load_unexpected:
|
||||
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
|
||||
|
||||
if len(unexpected_keys) > 0:
|
||||
|
||||
Reference in New Issue
Block a user