Add option to load a pretrained model with mismatched shapes (#12664)
* Add option to load a pretrained model with mismatched shapes * Fail at loading when mismatched shapes in Flax * Fix tests * Update src/transformers/modeling_flax_utils.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Address review comments Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -199,6 +199,10 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
|||||||
from_pt (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
from_pt (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
Load the model weights from a PyTorch checkpoint save file (see docstring of
|
Load the model weights from a PyTorch checkpoint save file (see docstring of
|
||||||
``pretrained_model_name_or_path`` argument).
|
``pretrained_model_name_or_path`` argument).
|
||||||
|
ignore_mismatched_sizes (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
|
Whether or not to ignore, instead of raising an error, weights from the checkpoint that do not have the same size
|
||||||
|
as the weights of the model (if for instance, you are instantiating a model with 10 labels from a
|
||||||
|
checkpoint with 3 labels).
|
||||||
force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||||
cached versions if they exist.
|
cached versions if they exist.
|
||||||
@@ -242,6 +246,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
|||||||
config = kwargs.pop("config", None)
|
config = kwargs.pop("config", None)
|
||||||
cache_dir = kwargs.pop("cache_dir", None)
|
cache_dir = kwargs.pop("cache_dir", None)
|
||||||
from_pt = kwargs.pop("from_pt", False)
|
from_pt = kwargs.pop("from_pt", False)
|
||||||
|
ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
|
||||||
force_download = kwargs.pop("force_download", False)
|
force_download = kwargs.pop("force_download", False)
|
||||||
resume_download = kwargs.pop("resume_download", False)
|
resume_download = kwargs.pop("resume_download", False)
|
||||||
proxies = kwargs.pop("proxies", None)
|
proxies = kwargs.pop("proxies", None)
|
||||||
@@ -367,6 +372,22 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
|||||||
missing_keys = model.required_params - set(state.keys())
|
missing_keys = model.required_params - set(state.keys())
|
||||||
unexpected_keys = set(state.keys()) - model.required_params
|
unexpected_keys = set(state.keys()) - model.required_params
|
||||||
|
|
||||||
|
# Mismatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
|
||||||
|
# matching the weights in the model.
|
||||||
|
mismatched_keys = []
|
||||||
|
for key in state.keys():
|
||||||
|
if key in random_state and state[key].shape != random_state[key].shape:
|
||||||
|
if ignore_mismatched_sizes:
|
||||||
|
mismatched_keys.append((key, state[key].shape, random_state[key].shape))
|
||||||
|
state[key] = random_state[key]
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Trying to load the pretrained weight for {key} failed: checkpoint has shape "
|
||||||
|
f"{state[key].shape} which is incompatible with the model shape {random_state[key].shape}. "
|
||||||
|
"Using `ignore_mismatched_sizes=True` if you really want to load this checkpoint inside this "
|
||||||
|
"model."
|
||||||
|
)
|
||||||
|
|
||||||
# add missing keys as random parameters
|
# add missing keys as random parameters
|
||||||
for missing_key in missing_keys:
|
for missing_key in missing_keys:
|
||||||
state[missing_key] = random_state[missing_key]
|
state[missing_key] = random_state[missing_key]
|
||||||
@@ -393,12 +414,24 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
|||||||
f"and are newly initialized: {missing_keys}\n"
|
f"and are newly initialized: {missing_keys}\n"
|
||||||
f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
|
f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
|
||||||
)
|
)
|
||||||
else:
|
elif len(mismatched_keys) == 0:
|
||||||
logger.info(
|
logger.info(
|
||||||
f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n"
|
f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n"
|
||||||
f"If your task is similar to the task the model of the checkpoint was trained on, "
|
f"If your task is similar to the task the model of the checkpoint was trained on, "
|
||||||
f"you can already use {model.__class__.__name__} for predictions without further training."
|
f"you can already use {model.__class__.__name__} for predictions without further training."
|
||||||
)
|
)
|
||||||
|
if len(mismatched_keys) > 0:
|
||||||
|
mismatched_warning = "\n".join(
|
||||||
|
[
|
||||||
|
f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
|
||||||
|
for key, shape1, shape2 in mismatched_keys
|
||||||
|
]
|
||||||
|
)
|
||||||
|
logger.warning(
|
||||||
|
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
|
||||||
|
f"and are newly initialized because the shapes did not match:\n{mismatched_warning}\n"
|
||||||
|
f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
|
||||||
|
)
|
||||||
|
|
||||||
# set correct parameters
|
# set correct parameters
|
||||||
model.params = unflatten_dict(state)
|
model.params = unflatten_dict(state)
|
||||||
|
|||||||
@@ -450,7 +450,7 @@ def input_processing(func, config, input_ids, **kwargs):
|
|||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
def load_tf_weights(model, resolved_archive_file, _prefix=None):
|
def load_tf_weights(model, resolved_archive_file, ignore_mismatched_sizes=False, _prefix=None):
|
||||||
"""
|
"""
|
||||||
Detect missing and unexpected layers and load the TF weights accordingly to their names and shapes.
|
Detect missing and unexpected layers and load the TF weights accordingly to their names and shapes.
|
||||||
|
|
||||||
@@ -459,12 +459,16 @@ def load_tf_weights(model, resolved_archive_file, _prefix=None):
|
|||||||
The model to load the weights into.
|
The model to load the weights into.
|
||||||
resolved_archive_file (:obj:`str`):
|
resolved_archive_file (:obj:`str`):
|
||||||
The location of the H5 file.
|
The location of the H5 file.
|
||||||
|
ignore_mismatched_sizes (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
|
Whether or not to ignore weights with shapes that don't match between the checkpoint and the model.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Two lists, one for the missing layers, and another one for the unexpected layers.
|
Three lists, one for the missing layers, another one for the unexpected layers, and a last one for the
|
||||||
|
mismatched layers.
|
||||||
"""
|
"""
|
||||||
missing_layers = []
|
missing_layers = []
|
||||||
unexpected_layers = []
|
unexpected_layers = []
|
||||||
|
mismatched_layers = []
|
||||||
|
|
||||||
# Read the H5 file
|
# Read the H5 file
|
||||||
with h5py.File(resolved_archive_file, "r") as f:
|
with h5py.File(resolved_archive_file, "r") as f:
|
||||||
@@ -533,9 +537,14 @@ def load_tf_weights(model, resolved_archive_file, _prefix=None):
|
|||||||
# If the two shapes are not compatible we raise an issue
|
# If the two shapes are not compatible we raise an issue
|
||||||
try:
|
try:
|
||||||
array = np.reshape(saved_weight_value, K.int_shape(symbolic_weight))
|
array = np.reshape(saved_weight_value, K.int_shape(symbolic_weight))
|
||||||
except AssertionError as e:
|
except ValueError as e:
|
||||||
e.args += (K.int_shape(symbolic_weight), saved_weight_value.shape)
|
if ignore_mismatched_sizes:
|
||||||
raise e
|
mismatched_layers.append(
|
||||||
|
(symbolic_weight_name, saved_weight_value.shape, K.int_shape(symbolic_weight))
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
raise e
|
||||||
else:
|
else:
|
||||||
array = saved_weight_value
|
array = saved_weight_value
|
||||||
|
|
||||||
@@ -549,7 +558,7 @@ def load_tf_weights(model, resolved_archive_file, _prefix=None):
|
|||||||
missing_layers.extend(list(symbolic_weights_names - saved_weight_names_set))
|
missing_layers.extend(list(symbolic_weights_names - saved_weight_names_set))
|
||||||
unexpected_layers.extend(list(saved_weight_names_set - symbolic_weights_names))
|
unexpected_layers.extend(list(saved_weight_names_set - symbolic_weights_names))
|
||||||
|
|
||||||
return missing_layers, unexpected_layers
|
return missing_layers, unexpected_layers, mismatched_layers
|
||||||
|
|
||||||
|
|
||||||
def init_copy_embeddings(old_embeddings, new_num_tokens):
|
def init_copy_embeddings(old_embeddings, new_num_tokens):
|
||||||
@@ -1123,6 +1132,10 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
|||||||
from_pt: (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
from_pt: (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
Load the model weights from a PyTorch state_dict save file (see docstring of
|
Load the model weights from a PyTorch state_dict save file (see docstring of
|
||||||
``pretrained_model_name_or_path`` argument).
|
``pretrained_model_name_or_path`` argument).
|
||||||
|
ignore_mismatched_sizes (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
|
Whether or not to ignore, instead of raising an error, weights from the checkpoint that do not have the same size
|
||||||
|
as the weights of the model (if for instance, you are instantiating a model with 10 labels from a
|
||||||
|
checkpoint with 3 labels).
|
||||||
cache_dir (:obj:`str`, `optional`):
|
cache_dir (:obj:`str`, `optional`):
|
||||||
Path to a directory in which a downloaded pretrained model configuration should be cached if the
|
Path to a directory in which a downloaded pretrained model configuration should be cached if the
|
||||||
standard cache should not be used.
|
standard cache should not be used.
|
||||||
@@ -1186,6 +1199,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
|||||||
config = kwargs.pop("config", None)
|
config = kwargs.pop("config", None)
|
||||||
cache_dir = kwargs.pop("cache_dir", None)
|
cache_dir = kwargs.pop("cache_dir", None)
|
||||||
from_pt = kwargs.pop("from_pt", False)
|
from_pt = kwargs.pop("from_pt", False)
|
||||||
|
ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
|
||||||
force_download = kwargs.pop("force_download", False)
|
force_download = kwargs.pop("force_download", False)
|
||||||
resume_download = kwargs.pop("resume_download", False)
|
resume_download = kwargs.pop("resume_download", False)
|
||||||
proxies = kwargs.pop("proxies", None)
|
proxies = kwargs.pop("proxies", None)
|
||||||
@@ -1307,7 +1321,12 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
|||||||
# 'by_name' allow us to do transfer learning by skipping/adding layers
|
# 'by_name' allow us to do transfer learning by skipping/adding layers
|
||||||
# see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357
|
# see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357
|
||||||
try:
|
try:
|
||||||
missing_keys, unexpected_keys = load_tf_weights(model, resolved_archive_file, load_weight_prefix)
|
missing_keys, unexpected_keys, mismatched_keys = load_tf_weights(
|
||||||
|
model,
|
||||||
|
resolved_archive_file,
|
||||||
|
ignore_mismatched_sizes=ignore_mismatched_sizes,
|
||||||
|
_prefix=load_weight_prefix,
|
||||||
|
)
|
||||||
except OSError:
|
except OSError:
|
||||||
raise OSError(
|
raise OSError(
|
||||||
"Unable to load weights from h5 file. "
|
"Unable to load weights from h5 file. "
|
||||||
@@ -1342,15 +1361,31 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
|||||||
f"and are newly initialized: {missing_keys}\n"
|
f"and are newly initialized: {missing_keys}\n"
|
||||||
f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
|
f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
|
||||||
)
|
)
|
||||||
else:
|
elif len(mismatched_keys) == 0:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"All the layers of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n"
|
f"All the layers of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n"
|
||||||
f"If your task is similar to the task the model of the checkpoint was trained on, "
|
f"If your task is similar to the task the model of the checkpoint was trained on, "
|
||||||
f"you can already use {model.__class__.__name__} for predictions without further training."
|
f"you can already use {model.__class__.__name__} for predictions without further training."
|
||||||
)
|
)
|
||||||
|
if len(mismatched_keys) > 0:
|
||||||
|
mismatched_warning = "\n".join(
|
||||||
|
[
|
||||||
|
f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
|
||||||
|
for key, shape1, shape2 in mismatched_keys
|
||||||
|
]
|
||||||
|
)
|
||||||
|
logger.warning(
|
||||||
|
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
|
||||||
|
f"and are newly initialized because the shapes did not match:\n{mismatched_warning}\n"
|
||||||
|
f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
|
||||||
|
)
|
||||||
|
|
||||||
if output_loading_info:
|
if output_loading_info:
|
||||||
loading_info = {"missing_keys": missing_keys, "unexpected_keys": unexpected_keys}
|
loading_info = {
|
||||||
|
"missing_keys": missing_keys,
|
||||||
|
"unexpected_keys": unexpected_keys,
|
||||||
|
"mismatched_keys": mismatched_keys,
|
||||||
|
}
|
||||||
|
|
||||||
return model, loading_info
|
return model, loading_info
|
||||||
|
|
||||||
|
|||||||
@@ -1037,6 +1037,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
from_flax (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
from_flax (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
Load the model weights from a Flax checkpoint save file (see docstring of
|
Load the model weights from a Flax checkpoint save file (see docstring of
|
||||||
``pretrained_model_name_or_path`` argument).
|
``pretrained_model_name_or_path`` argument).
|
||||||
|
ignore_mismatched_sizes (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
|
Whether or not to ignore, instead of raising an error, weights from the checkpoint that do not have the same size
|
||||||
|
as the weights of the model (if for instance, you are instantiating a model with 10 labels from a
|
||||||
|
checkpoint with 3 labels).
|
||||||
force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||||
cached versions if they exist.
|
cached versions if they exist.
|
||||||
@@ -1120,6 +1124,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
cache_dir = kwargs.pop("cache_dir", None)
|
cache_dir = kwargs.pop("cache_dir", None)
|
||||||
from_tf = kwargs.pop("from_tf", False)
|
from_tf = kwargs.pop("from_tf", False)
|
||||||
from_flax = kwargs.pop("from_flax", False)
|
from_flax = kwargs.pop("from_flax", False)
|
||||||
|
ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
|
||||||
force_download = kwargs.pop("force_download", False)
|
force_download = kwargs.pop("force_download", False)
|
||||||
resume_download = kwargs.pop("resume_download", False)
|
resume_download = kwargs.pop("resume_download", False)
|
||||||
proxies = kwargs.pop("proxies", None)
|
proxies = kwargs.pop("proxies", None)
|
||||||
@@ -1315,8 +1320,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
elif from_pt:
|
elif from_pt:
|
||||||
model, missing_keys, unexpected_keys, error_msgs = cls._load_state_dict_into_model(
|
model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_state_dict_into_model(
|
||||||
model, state_dict, pretrained_model_name_or_path, _fast_init=_fast_init
|
model,
|
||||||
|
state_dict,
|
||||||
|
pretrained_model_name_or_path,
|
||||||
|
ignore_mismatched_sizes=ignore_mismatched_sizes,
|
||||||
|
_fast_init=_fast_init,
|
||||||
)
|
)
|
||||||
|
|
||||||
# make sure token embedding weights are still tied if needed
|
# make sure token embedding weights are still tied if needed
|
||||||
@@ -1329,6 +1338,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
loading_info = {
|
loading_info = {
|
||||||
"missing_keys": missing_keys,
|
"missing_keys": missing_keys,
|
||||||
"unexpected_keys": unexpected_keys,
|
"unexpected_keys": unexpected_keys,
|
||||||
|
"mismatched_keys": mismatched_keys,
|
||||||
"error_msgs": error_msgs,
|
"error_msgs": error_msgs,
|
||||||
}
|
}
|
||||||
return model, loading_info
|
return model, loading_info
|
||||||
@@ -1336,7 +1346,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _load_state_dict_into_model(cls, model, state_dict, pretrained_model_name_or_path, _fast_init=True):
|
def _load_state_dict_into_model(
|
||||||
|
cls, model, state_dict, pretrained_model_name_or_path, ignore_mismatched_sizes=False, _fast_init=True
|
||||||
|
):
|
||||||
|
|
||||||
# Convert old format to new format if needed from a PyTorch state_dict
|
# Convert old format to new format if needed from a PyTorch state_dict
|
||||||
old_keys = []
|
old_keys = []
|
||||||
@@ -1354,7 +1366,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
state_dict[new_key] = state_dict.pop(old_key)
|
state_dict[new_key] = state_dict.pop(old_key)
|
||||||
|
|
||||||
# Retrieve missing & unexpected_keys
|
# Retrieve missing & unexpected_keys
|
||||||
expected_keys = list(model.state_dict().keys())
|
model_state_dict = model.state_dict()
|
||||||
|
expected_keys = list(model_state_dict.keys())
|
||||||
loaded_keys = list(state_dict.keys())
|
loaded_keys = list(state_dict.keys())
|
||||||
prefix = model.base_model_prefix
|
prefix = model.base_model_prefix
|
||||||
|
|
||||||
@@ -1374,6 +1387,26 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
missing_keys = list(set(expected_keys) - set(loaded_keys))
|
missing_keys = list(set(expected_keys) - set(loaded_keys))
|
||||||
unexpected_keys = list(set(loaded_keys) - set(expected_keys))
|
unexpected_keys = list(set(loaded_keys) - set(expected_keys))
|
||||||
|
|
||||||
|
# Mismatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
|
||||||
|
# matching the weights in the model.
|
||||||
|
mismatched_keys = []
|
||||||
|
if ignore_mismatched_sizes:
|
||||||
|
for checkpoint_key in loaded_keys:
|
||||||
|
model_key = checkpoint_key
|
||||||
|
if remove_prefix and checkpoint_key.startswith(prefix):
|
||||||
|
model_key = ".".join(checkpoint_key.split(".")[1:])
|
||||||
|
elif add_prefix:
|
||||||
|
model_key = f"{prefix}.{checkpoint_key}"
|
||||||
|
|
||||||
|
if (
|
||||||
|
model_key in model_state_dict
|
||||||
|
and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
|
||||||
|
):
|
||||||
|
mismatched_keys.append(
|
||||||
|
(checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
|
||||||
|
)
|
||||||
|
del state_dict[checkpoint_key]
|
||||||
|
|
||||||
# Some models may have keys that are not in the state by design, removing them before needlessly warning
|
# Some models may have keys that are not in the state by design, removing them before needlessly warning
|
||||||
# the user.
|
# the user.
|
||||||
if cls._keys_to_ignore_on_load_missing is not None:
|
if cls._keys_to_ignore_on_load_missing is not None:
|
||||||
@@ -1452,14 +1485,26 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
f"and are newly initialized: {missing_keys}\n"
|
f"and are newly initialized: {missing_keys}\n"
|
||||||
f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
|
f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
|
||||||
)
|
)
|
||||||
else:
|
elif len(mismatched_keys) == 0:
|
||||||
logger.info(
|
logger.info(
|
||||||
f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n"
|
f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n"
|
||||||
f"If your task is similar to the task the model of the checkpoint was trained on, "
|
f"If your task is similar to the task the model of the checkpoint was trained on, "
|
||||||
f"you can already use {model.__class__.__name__} for predictions without further training."
|
f"you can already use {model.__class__.__name__} for predictions without further training."
|
||||||
)
|
)
|
||||||
|
if len(mismatched_keys) > 0:
|
||||||
|
mismatched_warning = "\n".join(
|
||||||
|
[
|
||||||
|
f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
|
||||||
|
for key, shape1, shape2 in mismatched_keys
|
||||||
|
]
|
||||||
|
)
|
||||||
|
logger.warning(
|
||||||
|
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
|
||||||
|
f"and are newly initialized because the shapes did not match:\n{mismatched_warning}\n"
|
||||||
|
f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
|
||||||
|
)
|
||||||
|
|
||||||
return model, missing_keys, unexpected_keys, error_msgs
|
return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
|
||||||
|
|
||||||
def retrieve_modules_from_names(self, names, add_prefix=False, remove_prefix=False):
|
def retrieve_modules_from_names(self, names, add_prefix=False, remove_prefix=False):
|
||||||
module_keys = set([".".join(key.split(".")[:-1]) for key in names])
|
module_keys = set([".".join(key.split(".")[:-1]) for key in names])
|
||||||
|
|||||||
@@ -186,7 +186,7 @@ CONFIG_MAPPING = OrderedDict(
|
|||||||
("pegasus", PegasusConfig),
|
("pegasus", PegasusConfig),
|
||||||
("marian", MarianConfig),
|
("marian", MarianConfig),
|
||||||
("mbart", MBartConfig),
|
("mbart", MBartConfig),
|
||||||
("megatron_bert", MegatronBertConfig),
|
("megatron-bert", MegatronBertConfig),
|
||||||
("mpnet", MPNetConfig),
|
("mpnet", MPNetConfig),
|
||||||
("bart", BartConfig),
|
("bart", BartConfig),
|
||||||
("blenderbot", BlenderbotConfig),
|
("blenderbot", BlenderbotConfig),
|
||||||
@@ -252,7 +252,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
|
|||||||
("blenderbot", "Blenderbot"),
|
("blenderbot", "Blenderbot"),
|
||||||
("marian", "Marian"),
|
("marian", "Marian"),
|
||||||
("mbart", "mBART"),
|
("mbart", "mBART"),
|
||||||
("megatron_bert", "MegatronBert"),
|
("megatron-bert", "MegatronBert"),
|
||||||
("bart", "BART"),
|
("bart", "BART"),
|
||||||
("reformer", "Reformer"),
|
("reformer", "Reformer"),
|
||||||
("longformer", "Longformer"),
|
("longformer", "Longformer"),
|
||||||
|
|||||||
@@ -760,10 +760,6 @@ class DebertaPreTrainedModel(PreTrainedModel):
|
|||||||
_keys_to_ignore_on_load_missing = ["position_ids"]
|
_keys_to_ignore_on_load_missing = ["position_ids"]
|
||||||
_keys_to_ignore_on_load_unexpected = ["position_embeddings"]
|
_keys_to_ignore_on_load_unexpected = ["position_embeddings"]
|
||||||
|
|
||||||
def __init__(self, config):
|
|
||||||
super().__init__(config)
|
|
||||||
self._register_load_state_dict_pre_hook(self._pre_load_hook)
|
|
||||||
|
|
||||||
def _init_weights(self, module):
|
def _init_weights(self, module):
|
||||||
"""Initialize the weights."""
|
"""Initialize the weights."""
|
||||||
if isinstance(module, nn.Linear):
|
if isinstance(module, nn.Linear):
|
||||||
@@ -777,25 +773,6 @@ class DebertaPreTrainedModel(PreTrainedModel):
|
|||||||
if module.padding_idx is not None:
|
if module.padding_idx is not None:
|
||||||
module.weight.data[module.padding_idx].zero_()
|
module.weight.data[module.padding_idx].zero_()
|
||||||
|
|
||||||
def _pre_load_hook(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
|
|
||||||
"""
|
|
||||||
Removes the classifier if it doesn't have the correct number of labels.
|
|
||||||
"""
|
|
||||||
self_state = self.state_dict()
|
|
||||||
if (
|
|
||||||
("classifier.weight" in self_state)
|
|
||||||
and ("classifier.weight" in state_dict)
|
|
||||||
and self_state["classifier.weight"].size() != state_dict["classifier.weight"].size()
|
|
||||||
):
|
|
||||||
logger.warning(
|
|
||||||
f"The checkpoint classifier head has a shape {state_dict['classifier.weight'].size()} and this model "
|
|
||||||
f"classifier head has a shape {self_state['classifier.weight'].size()}. Ignoring the checkpoint "
|
|
||||||
f"weights. You should train your model on new data."
|
|
||||||
)
|
|
||||||
del state_dict["classifier.weight"]
|
|
||||||
if "classifier.bias" in state_dict:
|
|
||||||
del state_dict["classifier.bias"]
|
|
||||||
|
|
||||||
|
|
||||||
DEBERTA_START_DOCSTRING = r"""
|
DEBERTA_START_DOCSTRING = r"""
|
||||||
The DeBERTa model was proposed in `DeBERTa: Decoding-enhanced BERT with Disentangled Attention
|
The DeBERTa model was proposed in `DeBERTa: Decoding-enhanced BERT with Disentangled Attention
|
||||||
|
|||||||
@@ -881,10 +881,6 @@ class DebertaV2PreTrainedModel(PreTrainedModel):
|
|||||||
_keys_to_ignore_on_load_missing = ["position_ids"]
|
_keys_to_ignore_on_load_missing = ["position_ids"]
|
||||||
_keys_to_ignore_on_load_unexpected = ["position_embeddings"]
|
_keys_to_ignore_on_load_unexpected = ["position_embeddings"]
|
||||||
|
|
||||||
def __init__(self, config):
|
|
||||||
super().__init__(config)
|
|
||||||
self._register_load_state_dict_pre_hook(self._pre_load_hook)
|
|
||||||
|
|
||||||
def _init_weights(self, module):
|
def _init_weights(self, module):
|
||||||
"""Initialize the weights."""
|
"""Initialize the weights."""
|
||||||
if isinstance(module, nn.Linear):
|
if isinstance(module, nn.Linear):
|
||||||
@@ -898,25 +894,6 @@ class DebertaV2PreTrainedModel(PreTrainedModel):
|
|||||||
if module.padding_idx is not None:
|
if module.padding_idx is not None:
|
||||||
module.weight.data[module.padding_idx].zero_()
|
module.weight.data[module.padding_idx].zero_()
|
||||||
|
|
||||||
def _pre_load_hook(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
|
|
||||||
"""
|
|
||||||
Removes the classifier if it doesn't have the correct number of labels.
|
|
||||||
"""
|
|
||||||
self_state = self.state_dict()
|
|
||||||
if (
|
|
||||||
("classifier.weight" in self_state)
|
|
||||||
and ("classifier.weight" in state_dict)
|
|
||||||
and self_state["classifier.weight"].size() != state_dict["classifier.weight"].size()
|
|
||||||
):
|
|
||||||
logger.warning(
|
|
||||||
f"The checkpoint classifier head has a shape {state_dict['classifier.weight'].size()} and this model "
|
|
||||||
f"classifier head has a shape {self_state['classifier.weight'].size()}. Ignoring the checkpoint "
|
|
||||||
f"weights. You should train your model on new data."
|
|
||||||
)
|
|
||||||
del state_dict["classifier.weight"]
|
|
||||||
if "classifier.bias" in state_dict:
|
|
||||||
del state_dict["classifier.bias"]
|
|
||||||
|
|
||||||
|
|
||||||
DEBERTA_START_DOCSTRING = r"""
|
DEBERTA_START_DOCSTRING = r"""
|
||||||
The DeBERTa model was proposed in `DeBERTa: Decoding-enhanced BERT with Disentangled Attention
|
The DeBERTa model was proposed in `DeBERTa: Decoding-enhanced BERT with Disentangled Attention
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ from typing import Dict, List, Tuple
|
|||||||
|
|
||||||
from huggingface_hub import HfApi
|
from huggingface_hub import HfApi
|
||||||
from requests.exceptions import HTTPError
|
from requests.exceptions import HTTPError
|
||||||
from transformers import AutoModel, is_torch_available, logging
|
from transformers import AutoModel, AutoModelForSequenceClassification, is_torch_available, logging
|
||||||
from transformers.file_utils import WEIGHTS_NAME, is_torch_fx_available
|
from transformers.file_utils import WEIGHTS_NAME, is_torch_fx_available
|
||||||
from transformers.models.auto import get_values
|
from transformers.models.auto import get_values
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
@@ -1532,6 +1532,35 @@ class ModelTesterMixin:
|
|||||||
|
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
|
||||||
|
def test_load_with_mismatched_shapes(self):
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
if model_class not in get_values(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING):
|
||||||
|
continue
|
||||||
|
|
||||||
|
with self.subTest(msg=f"Testing {model_class}"):
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
model = model_class(config)
|
||||||
|
model.save_pretrained(tmp_dir)
|
||||||
|
|
||||||
|
# Fails when we don't set ignore_mismatched_sizes=True
|
||||||
|
with self.assertRaises(RuntimeError) as e:
|
||||||
|
print(type(e))
|
||||||
|
new_model = AutoModelForSequenceClassification.from_pretrained(tmp_dir, num_labels=42)
|
||||||
|
|
||||||
|
logger = logging.get_logger("transformers.modeling_utils")
|
||||||
|
with CaptureLogger(logger) as cl:
|
||||||
|
new_model = AutoModelForSequenceClassification.from_pretrained(
|
||||||
|
tmp_dir, num_labels=42, ignore_mismatched_sizes=True
|
||||||
|
)
|
||||||
|
self.assertIn("the shapes did not match", cl.out)
|
||||||
|
|
||||||
|
new_model.to(torch_device)
|
||||||
|
inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||||
|
logits = new_model(**inputs).logits
|
||||||
|
self.assertEqual(logits.shape[1], 42)
|
||||||
|
|
||||||
|
|
||||||
global_rng = random.Random()
|
global_rng = random.Random()
|
||||||
|
|
||||||
|
|||||||
@@ -24,17 +24,19 @@ import numpy as np
|
|||||||
import transformers
|
import transformers
|
||||||
from huggingface_hub import HfApi
|
from huggingface_hub import HfApi
|
||||||
from requests.exceptions import HTTPError
|
from requests.exceptions import HTTPError
|
||||||
from transformers import BertConfig, FlaxBertModel, is_flax_available, is_torch_available
|
from transformers import BertConfig, is_flax_available, is_torch_available
|
||||||
from transformers.models.auto import get_values
|
from transformers.models.auto import get_values
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
ENDPOINT_STAGING,
|
ENDPOINT_STAGING,
|
||||||
PASS,
|
PASS,
|
||||||
USER,
|
USER,
|
||||||
|
CaptureLogger,
|
||||||
is_pt_flax_cross_test,
|
is_pt_flax_cross_test,
|
||||||
is_staging_test,
|
is_staging_test,
|
||||||
require_flax,
|
require_flax,
|
||||||
slow,
|
slow,
|
||||||
)
|
)
|
||||||
|
from transformers.utils import logging
|
||||||
|
|
||||||
|
|
||||||
if is_flax_available():
|
if is_flax_available():
|
||||||
@@ -45,7 +47,13 @@ if is_flax_available():
|
|||||||
import jaxlib.xla_extension as jax_xla
|
import jaxlib.xla_extension as jax_xla
|
||||||
from flax.core.frozen_dict import unfreeze
|
from flax.core.frozen_dict import unfreeze
|
||||||
from flax.traverse_util import flatten_dict
|
from flax.traverse_util import flatten_dict
|
||||||
from transformers import FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING, FLAX_MODEL_MAPPING
|
from transformers import (
|
||||||
|
FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
|
||||||
|
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
||||||
|
FLAX_MODEL_MAPPING,
|
||||||
|
FlaxAutoModelForSequenceClassification,
|
||||||
|
FlaxBertModel,
|
||||||
|
)
|
||||||
from transformers.modeling_flax_pytorch_utils import (
|
from transformers.modeling_flax_pytorch_utils import (
|
||||||
convert_pytorch_state_dict_to_flax,
|
convert_pytorch_state_dict_to_flax,
|
||||||
load_flax_weights_in_pytorch_model,
|
load_flax_weights_in_pytorch_model,
|
||||||
@@ -516,6 +524,32 @@ class FlaxModelTesterMixin:
|
|||||||
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
|
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_load_with_mismatched_shapes(self):
    """Check loading a saved Flax checkpoint with a different ``num_labels``.

    Each class in ``self.all_model_classes`` that appears in
    ``FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING`` is instantiated from
    the tester's config, saved to a temporary directory, and reloaded through
    ``FlaxAutoModelForSequenceClassification`` with ``num_labels=42``. The
    plain reload must raise ``ValueError``; the reload with
    ``ignore_mismatched_sizes=True`` must log "the shapes did not match" and
    yield logits whose second dimension is 42.
    """
    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
    num_labels = 42

    for model_class in self.all_model_classes:
        # Classes outside the sequence-classification mapping are skipped.
        if model_class in get_values(FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING):
            with self.subTest(msg=f"Testing {model_class}"), tempfile.TemporaryDirectory() as tmp_dir:
                saved_model = model_class(config)
                saved_model.save_pretrained(tmp_dir)

                # Fails when we don't set ignore_mismatched_sizes=True
                with self.assertRaises(ValueError):
                    FlaxAutoModelForSequenceClassification.from_pretrained(tmp_dir, num_labels=num_labels)

                flax_logger = logging.get_logger("transformers.modeling_flax_utils")
                with CaptureLogger(flax_logger) as captured:
                    reloaded_model = FlaxAutoModelForSequenceClassification.from_pretrained(
                        tmp_dir, num_labels=num_labels, ignore_mismatched_sizes=True
                    )
                # Checked after the capture context has exited.
                self.assertIn("the shapes did not match", captured.out)

                outputs = reloaded_model(**inputs_dict)
                self.assertEqual(outputs["logits"].shape[1], num_labels)
|
||||||
|
|
||||||
|
|
||||||
@require_flax
|
@require_flax
|
||||||
@is_staging_test
|
@is_staging_test
|
||||||
|
|||||||
@@ -32,6 +32,7 @@ from transformers.testing_utils import (
|
|||||||
ENDPOINT_STAGING,
|
ENDPOINT_STAGING,
|
||||||
PASS,
|
PASS,
|
||||||
USER,
|
USER,
|
||||||
|
CaptureLogger,
|
||||||
_tf_gpu_memory_limit,
|
_tf_gpu_memory_limit,
|
||||||
is_pt_tf_cross_test,
|
is_pt_tf_cross_test,
|
||||||
is_staging_test,
|
is_staging_test,
|
||||||
@@ -40,6 +41,7 @@ from transformers.testing_utils import (
|
|||||||
slow,
|
slow,
|
||||||
tooslow,
|
tooslow,
|
||||||
)
|
)
|
||||||
|
from transformers.utils import logging
|
||||||
|
|
||||||
|
|
||||||
if is_tf_available():
|
if is_tf_available():
|
||||||
@@ -57,6 +59,7 @@ if is_tf_available():
|
|||||||
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
||||||
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
||||||
BertConfig,
|
BertConfig,
|
||||||
|
TFAutoModelForSequenceClassification,
|
||||||
TFBertModel,
|
TFBertModel,
|
||||||
TFSharedEmbeddings,
|
TFSharedEmbeddings,
|
||||||
tf_top_k_top_p_filtering,
|
tf_top_k_top_p_filtering,
|
||||||
@@ -1308,6 +1311,34 @@ class TFModelTesterMixin:
|
|||||||
attn_weights = out[attn_name] if attn_name == attention_names[0] else out[attn_name][-1]
|
attn_weights = out[attn_name] if attn_name == attention_names[0] else out[attn_name][-1]
|
||||||
self.assertEqual(sum([tf.reduce_sum(w).numpy() for w in attn_weights]), 0.0)
|
self.assertEqual(sum([tf.reduce_sum(w).numpy() for w in attn_weights]), 0.0)
|
||||||
|
|
||||||
|
def test_load_with_mismatched_shapes(self):
    """Check loading a saved TF checkpoint with a different ``num_labels``.

    Each class in ``self.all_model_classes`` that appears in
    ``TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING`` is instantiated from the
    tester's config, called once on prepared inputs, saved to a temporary
    directory, and reloaded through ``TFAutoModelForSequenceClassification``
    with ``num_labels=42``. The plain reload must raise ``ValueError``; the
    reload with ``ignore_mismatched_sizes=True`` must log "the shapes did not
    match" and yield logits whose second dimension is 42.
    """
    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
    num_labels = 42

    for model_class in self.all_model_classes:
        # Classes outside the sequence-classification mapping are skipped.
        if model_class in get_values(TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING):
            with self.subTest(msg=f"Testing {model_class}"), tempfile.TemporaryDirectory() as tmp_dir:
                saved_model = model_class(config)
                inputs = self._prepare_for_class(inputs_dict, model_class)
                # Forward pass before saving; presumably needed so the weights
                # exist when the checkpoint is written (TODO confirm).
                saved_model(**inputs)
                saved_model.save_pretrained(tmp_dir)

                # Fails when we don't set ignore_mismatched_sizes=True
                with self.assertRaises(ValueError):
                    TFAutoModelForSequenceClassification.from_pretrained(tmp_dir, num_labels=num_labels)

                tf_logger = logging.get_logger("transformers.modeling_tf_utils")
                with CaptureLogger(tf_logger) as captured:
                    reloaded_model = TFAutoModelForSequenceClassification.from_pretrained(
                        tmp_dir, num_labels=num_labels, ignore_mismatched_sizes=True
                    )
                # Checked after the capture context has exited.
                self.assertIn("the shapes did not match", captured.out)

                outputs = reloaded_model(**inputs)
                self.assertEqual(outputs.logits.shape[1], num_labels)
|
||||||
|
|
||||||
def _generate_random_bad_tokens(self, num_bad_tokens, model):
|
def _generate_random_bad_tokens(self, num_bad_tokens, model):
|
||||||
# special tokens cannot be bad tokens
|
# special tokens cannot be bad tokens
|
||||||
special_tokens = []
|
special_tokens = []
|
||||||
|
|||||||
Reference in New Issue
Block a user