From 90178b0cefe94fef258a39cff5019b5ec150597b Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Tue, 13 Jul 2021 10:15:15 -0400 Subject: [PATCH] Add option to load a pretrained model with mismatched shapes (#12664) * Add option to load a pretrained model with mismatched shapes * Fail at loading when mismatched shapes in Flax * Fix tests * Update src/transformers/modeling_flax_utils.py Co-authored-by: Patrick von Platen * Address review comments Co-authored-by: Patrick von Platen --- src/transformers/modeling_flax_utils.py | 35 +++++++++++- src/transformers/modeling_tf_utils.py | 53 ++++++++++++++--- src/transformers/modeling_utils.py | 57 +++++++++++++++++-- .../models/auto/configuration_auto.py | 4 +- .../models/deberta/modeling_deberta.py | 23 -------- .../models/deberta_v2/modeling_deberta_v2.py | 23 -------- tests/test_modeling_common.py | 31 +++++++++- tests/test_modeling_flax_common.py | 38 ++++++++++++- tests/test_modeling_tf_common.py | 31 ++++++++++ 9 files changed, 228 insertions(+), 67 deletions(-) diff --git a/src/transformers/modeling_flax_utils.py b/src/transformers/modeling_flax_utils.py index 00ccdfcfb7..b93a97c628 100644 --- a/src/transformers/modeling_flax_utils.py +++ b/src/transformers/modeling_flax_utils.py @@ -199,6 +199,10 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin): from_pt (:obj:`bool`, `optional`, defaults to :obj:`False`): Load the model weights from a PyTorch checkpoint save file (see docstring of ``pretrained_model_name_or_path`` argument). + ignore_mismatched_sizes (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not to raise an error if some of the weights from the checkpoint do not have the same size + as the weights of the model (if for instance, you are instantiating a model with 10 labels from a + checkpoint with 3 labels). 
force_download (:obj:`bool`, `optional`, defaults to :obj:`False`): Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist. @@ -242,6 +246,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin): config = kwargs.pop("config", None) cache_dir = kwargs.pop("cache_dir", None) from_pt = kwargs.pop("from_pt", False) + ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False) force_download = kwargs.pop("force_download", False) resume_download = kwargs.pop("resume_download", False) proxies = kwargs.pop("proxies", None) @@ -367,6 +372,22 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin): missing_keys = model.required_params - set(state.keys()) unexpected_keys = set(state.keys()) - model.required_params + # Mismatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not + # matching the weights in the model. + mismatched_keys = [] + for key in state.keys(): + if key in random_state and state[key].shape != random_state[key].shape: + if ignore_mismatched_sizes: + mismatched_keys.append((key, state[key].shape, random_state[key].shape)) + state[key] = random_state[key] + else: + raise ValueError( + f"Trying to load the pretrained weight for {key} failed: checkpoint has shape " + f"{state[key].shape} which is incompatible with the model shape {random_state[key].shape}. " + "Using `ignore_mismatched_sizes=True` if you really want to load this checkpoint inside this " + "model." + ) + # add missing keys as random parameters for missing_key in missing_keys: state[missing_key] = random_state[missing_key] @@ -393,12 +414,24 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin): f"and are newly initialized: {missing_keys}\n" f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference." 
) - else: + elif len(mismatched_keys) == 0: logger.info( f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n" f"If your task is similar to the task the model of the checkpoint was trained on, " f"you can already use {model.__class__.__name__} for predictions without further training." ) + if len(mismatched_keys) > 0: + mismatched_warning = "\n".join( + [ + f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated" + for key, shape1, shape2 in mismatched_keys + ] + ) + logger.warning( + f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} " + f"and are newly initialized because the shapes did not match:\n{mismatched_warning}\n" + f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference." + ) # set correct parameters model.params = unflatten_dict(state) diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index b2587353b6..ff84b80af0 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -450,7 +450,7 @@ def input_processing(func, config, input_ids, **kwargs): return output -def load_tf_weights(model, resolved_archive_file, _prefix=None): +def load_tf_weights(model, resolved_archive_file, ignore_mismatched_sizes=False, _prefix=None): """ Detect missing and unexpected layers and load the TF weights accordingly to their names and shapes. @@ -459,12 +459,16 @@ def load_tf_weights(model, resolved_archive_file, _prefix=None): The model to load the weights into. resolved_archive_file (:obj:`str`): The location of the H5 file. + ignore_mismatched_sizes (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not to ignore weights with shapes that don't match between the checkpoint and the model. 
Returns: - Two lists, one for the missing layers, and another one for the unexpected layers. + Three lists, one for the missing layers, another one for the unexpected layers, and a last one for the + mismatched layers. """ missing_layers = [] unexpected_layers = [] + mismatched_layers = [] # Read the H5 file with h5py.File(resolved_archive_file, "r") as f: @@ -533,9 +537,14 @@ def load_tf_weights(model, resolved_archive_file, _prefix=None): # If the two shapes are not compatible we raise an issue try: array = np.reshape(saved_weight_value, K.int_shape(symbolic_weight)) - except AssertionError as e: - e.args += (K.int_shape(symbolic_weight), saved_weight_value.shape) - raise e + except ValueError as e: + if ignore_mismatched_sizes: + mismatched_layers.append( + (symbolic_weight_name, saved_weight_value.shape, K.int_shape(symbolic_weight)) + ) + continue + else: + raise e else: array = saved_weight_value @@ -549,7 +558,7 @@ def load_tf_weights(model, resolved_archive_file, _prefix=None): missing_layers.extend(list(symbolic_weights_names - saved_weight_names_set)) unexpected_layers.extend(list(saved_weight_names_set - symbolic_weights_names)) - return missing_layers, unexpected_layers + return missing_layers, unexpected_layers, mismatched_layers def init_copy_embeddings(old_embeddings, new_num_tokens): @@ -1123,6 +1132,10 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu from_pt: (:obj:`bool`, `optional`, defaults to :obj:`False`): Load the model weights from a PyTorch state_dict save file (see docstring of ``pretrained_model_name_or_path`` argument). + ignore_mismatched_sizes (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not to raise an error if some of the weights from the checkpoint do not have the same size + as the weights of the model (if for instance, you are instantiating a model with 10 labels from a + checkpoint with 3 labels). 
cache_dir (:obj:`str`, `optional`): Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used. @@ -1186,6 +1199,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu config = kwargs.pop("config", None) cache_dir = kwargs.pop("cache_dir", None) from_pt = kwargs.pop("from_pt", False) + ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False) force_download = kwargs.pop("force_download", False) resume_download = kwargs.pop("resume_download", False) proxies = kwargs.pop("proxies", None) @@ -1307,7 +1321,12 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu # 'by_name' allow us to do transfer learning by skipping/adding layers # see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357 try: - missing_keys, unexpected_keys = load_tf_weights(model, resolved_archive_file, load_weight_prefix) + missing_keys, unexpected_keys, mismatched_keys = load_tf_weights( + model, + resolved_archive_file, + ignore_mismatched_sizes=ignore_mismatched_sizes, + _prefix=load_weight_prefix, + ) except OSError: raise OSError( "Unable to load weights from h5 file. " @@ -1342,15 +1361,31 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu f"and are newly initialized: {missing_keys}\n" f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference." ) - else: + elif len(mismatched_keys) == 0: logger.warning( f"All the layers of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n" f"If your task is similar to the task the model of the checkpoint was trained on, " f"you can already use {model.__class__.__name__} for predictions without further training." 
) + if len(mismatched_keys) > 0: + mismatched_warning = "\n".join( + [ + f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated" + for key, shape1, shape2 in mismatched_keys + ] + ) + logger.warning( + f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} " + f"and are newly initialized because the shapes did not match:\n{mismatched_warning}\n" + f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference." + ) if output_loading_info: - loading_info = {"missing_keys": missing_keys, "unexpected_keys": unexpected_keys} + loading_info = { + "missing_keys": missing_keys, + "unexpected_keys": unexpected_keys, + "mismatched_keys": mismatched_keys, + } return model, loading_info diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index f9af6129c3..207503ccf9 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1037,6 +1037,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix from_flax (:obj:`bool`, `optional`, defaults to :obj:`False`): Load the model weights from a Flax checkpoint save file (see docstring of ``pretrained_model_name_or_path`` argument). + ignore_mismatched_sizes (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not to raise an error if some of the weights from the checkpoint do not have the same size + as the weights of the model (if for instance, you are instantiating a model with 10 labels from a + checkpoint with 3 labels). force_download (:obj:`bool`, `optional`, defaults to :obj:`False`): Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist. 
@@ -1120,6 +1124,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix cache_dir = kwargs.pop("cache_dir", None) from_tf = kwargs.pop("from_tf", False) from_flax = kwargs.pop("from_flax", False) + ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False) force_download = kwargs.pop("force_download", False) resume_download = kwargs.pop("resume_download", False) proxies = kwargs.pop("proxies", None) @@ -1315,8 +1320,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ) raise elif from_pt: - model, missing_keys, unexpected_keys, error_msgs = cls._load_state_dict_into_model( - model, state_dict, pretrained_model_name_or_path, _fast_init=_fast_init + model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_state_dict_into_model( + model, + state_dict, + pretrained_model_name_or_path, + ignore_mismatched_sizes=ignore_mismatched_sizes, + _fast_init=_fast_init, ) # make sure token embedding weights are still tied if needed @@ -1329,6 +1338,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix loading_info = { "missing_keys": missing_keys, "unexpected_keys": unexpected_keys, + "mismatched_keys": mismatched_keys, "error_msgs": error_msgs, } return model, loading_info @@ -1336,7 +1346,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix return model @classmethod - def _load_state_dict_into_model(cls, model, state_dict, pretrained_model_name_or_path, _fast_init=True): + def _load_state_dict_into_model( + cls, model, state_dict, pretrained_model_name_or_path, ignore_mismatched_sizes=False, _fast_init=True + ): # Convert old format to new format if needed from a PyTorch state_dict old_keys = [] @@ -1354,7 +1366,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix state_dict[new_key] = state_dict.pop(old_key) # Retrieve missing & unexpected_keys - expected_keys = 
list(model.state_dict().keys()) + model_state_dict = model.state_dict() + expected_keys = list(model_state_dict.keys()) loaded_keys = list(state_dict.keys()) prefix = model.base_model_prefix @@ -1374,6 +1387,26 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix missing_keys = list(set(expected_keys) - set(loaded_keys)) unexpected_keys = list(set(loaded_keys) - set(expected_keys)) + # Mismatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not + # matching the weights in the model. + mismatched_keys = [] + if ignore_mismatched_sizes: + for checkpoint_key in loaded_keys: + model_key = checkpoint_key + if remove_prefix and checkpoint_key.startswith(prefix): + model_key = ".".join(checkpoint_key.split(".")[1:]) + elif add_prefix: + model_key = f"{prefix}.{checkpoint_key}" + + if ( + model_key in model_state_dict + and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape + ): + mismatched_keys.append( + (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape) + ) + del state_dict[checkpoint_key] + # Some models may have keys that are not in the state by design, removing them before needlessly warning # the user. if cls._keys_to_ignore_on_load_missing is not None: @@ -1452,14 +1485,26 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix f"and are newly initialized: {missing_keys}\n" f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference." ) - else: + elif len(mismatched_keys) == 0: logger.info( f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n" f"If your task is similar to the task the model of the checkpoint was trained on, " f"you can already use {model.__class__.__name__} for predictions without further training." 
) + if len(mismatched_keys) > 0: + mismatched_warning = "\n".join( + [ + f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated" + for key, shape1, shape2 in mismatched_keys + ] + ) + logger.warning( + f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} " + f"and are newly initialized because the shapes did not match:\n{mismatched_warning}\n" + f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference." + ) - return model, missing_keys, unexpected_keys, error_msgs + return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs def retrieve_modules_from_names(self, names, add_prefix=False, remove_prefix=False): module_keys = set([".".join(key.split(".")[:-1]) for key in names]) diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index d9ed2bec77..cf5fc55c05 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -186,7 +186,7 @@ CONFIG_MAPPING = OrderedDict( ("pegasus", PegasusConfig), ("marian", MarianConfig), ("mbart", MBartConfig), - ("megatron_bert", MegatronBertConfig), + ("megatron-bert", MegatronBertConfig), ("mpnet", MPNetConfig), ("bart", BartConfig), ("blenderbot", BlenderbotConfig), @@ -252,7 +252,7 @@ MODEL_NAMES_MAPPING = OrderedDict( ("blenderbot", "Blenderbot"), ("marian", "Marian"), ("mbart", "mBART"), - ("megatron_bert", "MegatronBert"), + ("megatron-bert", "MegatronBert"), ("bart", "BART"), ("reformer", "Reformer"), ("longformer", "Longformer"), diff --git a/src/transformers/models/deberta/modeling_deberta.py b/src/transformers/models/deberta/modeling_deberta.py index 7bdc00ebd7..b1aaee0c26 100644 --- a/src/transformers/models/deberta/modeling_deberta.py +++ b/src/transformers/models/deberta/modeling_deberta.py @@ -760,10 +760,6 @@ class 
DebertaPreTrainedModel(PreTrainedModel): _keys_to_ignore_on_load_missing = ["position_ids"] _keys_to_ignore_on_load_unexpected = ["position_embeddings"] - def __init__(self, config): - super().__init__(config) - self._register_load_state_dict_pre_hook(self._pre_load_hook) - def _init_weights(self, module): """Initialize the weights.""" if isinstance(module, nn.Linear): @@ -777,25 +773,6 @@ class DebertaPreTrainedModel(PreTrainedModel): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _pre_load_hook(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): - """ - Removes the classifier if it doesn't have the correct number of labels. - """ - self_state = self.state_dict() - if ( - ("classifier.weight" in self_state) - and ("classifier.weight" in state_dict) - and self_state["classifier.weight"].size() != state_dict["classifier.weight"].size() - ): - logger.warning( - f"The checkpoint classifier head has a shape {state_dict['classifier.weight'].size()} and this model " - f"classifier head has a shape {self_state['classifier.weight'].size()}. Ignoring the checkpoint " - f"weights. You should train your model on new data." 
- ) - del state_dict["classifier.weight"] - if "classifier.bias" in state_dict: - del state_dict["classifier.bias"] - DEBERTA_START_DOCSTRING = r""" The DeBERTa model was proposed in `DeBERTa: Decoding-enhanced BERT with Disentangled Attention diff --git a/src/transformers/models/deberta_v2/modeling_deberta_v2.py b/src/transformers/models/deberta_v2/modeling_deberta_v2.py index f186274380..eddd06129c 100644 --- a/src/transformers/models/deberta_v2/modeling_deberta_v2.py +++ b/src/transformers/models/deberta_v2/modeling_deberta_v2.py @@ -881,10 +881,6 @@ class DebertaV2PreTrainedModel(PreTrainedModel): _keys_to_ignore_on_load_missing = ["position_ids"] _keys_to_ignore_on_load_unexpected = ["position_embeddings"] - def __init__(self, config): - super().__init__(config) - self._register_load_state_dict_pre_hook(self._pre_load_hook) - def _init_weights(self, module): """Initialize the weights.""" if isinstance(module, nn.Linear): @@ -898,25 +894,6 @@ class DebertaV2PreTrainedModel(PreTrainedModel): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _pre_load_hook(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): - """ - Removes the classifier if it doesn't have the correct number of labels. - """ - self_state = self.state_dict() - if ( - ("classifier.weight" in self_state) - and ("classifier.weight" in state_dict) - and self_state["classifier.weight"].size() != state_dict["classifier.weight"].size() - ): - logger.warning( - f"The checkpoint classifier head has a shape {state_dict['classifier.weight'].size()} and this model " - f"classifier head has a shape {self_state['classifier.weight'].size()}. Ignoring the checkpoint " - f"weights. You should train your model on new data." 
- ) - del state_dict["classifier.weight"] - if "classifier.bias" in state_dict: - del state_dict["classifier.bias"] - DEBERTA_START_DOCSTRING = r""" The DeBERTa model was proposed in `DeBERTa: Decoding-enhanced BERT with Disentangled Attention diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 4fd6217223..3d4c57c158 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -25,7 +25,7 @@ from typing import Dict, List, Tuple from huggingface_hub import HfApi from requests.exceptions import HTTPError -from transformers import AutoModel, is_torch_available, logging +from transformers import AutoModel, AutoModelForSequenceClassification, is_torch_available, logging from transformers.file_utils import WEIGHTS_NAME, is_torch_fx_available from transformers.models.auto import get_values from transformers.testing_utils import ( @@ -1532,6 +1532,35 @@ class ModelTesterMixin: loss.backward() + def test_load_with_mismatched_shapes(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + if model_class not in get_values(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING): + continue + + with self.subTest(msg=f"Testing {model_class}"): + with tempfile.TemporaryDirectory() as tmp_dir: + model = model_class(config) + model.save_pretrained(tmp_dir) + + # Fails when we don't set ignore_mismatched_sizes=True + with self.assertRaises(RuntimeError) as e: + print(type(e)) + new_model = AutoModelForSequenceClassification.from_pretrained(tmp_dir, num_labels=42) + + logger = logging.get_logger("transformers.modeling_utils") + with CaptureLogger(logger) as cl: + new_model = AutoModelForSequenceClassification.from_pretrained( + tmp_dir, num_labels=42, ignore_mismatched_sizes=True + ) + self.assertIn("the shapes did not match", cl.out) + + new_model.to(torch_device) + inputs = self._prepare_for_class(inputs_dict, model_class) + logits = new_model(**inputs).logits + 
self.assertEqual(logits.shape[1], 42) + global_rng = random.Random() diff --git a/tests/test_modeling_flax_common.py b/tests/test_modeling_flax_common.py index 2a64be4a41..2646751459 100644 --- a/tests/test_modeling_flax_common.py +++ b/tests/test_modeling_flax_common.py @@ -24,17 +24,19 @@ import numpy as np import transformers from huggingface_hub import HfApi from requests.exceptions import HTTPError -from transformers import BertConfig, FlaxBertModel, is_flax_available, is_torch_available +from transformers import BertConfig, is_flax_available, is_torch_available from transformers.models.auto import get_values from transformers.testing_utils import ( ENDPOINT_STAGING, PASS, USER, + CaptureLogger, is_pt_flax_cross_test, is_staging_test, require_flax, slow, ) +from transformers.utils import logging if is_flax_available(): @@ -45,7 +47,13 @@ if is_flax_available(): import jaxlib.xla_extension as jax_xla from flax.core.frozen_dict import unfreeze from flax.traverse_util import flatten_dict - from transformers import FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING, FLAX_MODEL_MAPPING + from transformers import ( + FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING, + FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, + FLAX_MODEL_MAPPING, + FlaxAutoModelForSequenceClassification, + FlaxBertModel, + ) from transformers.modeling_flax_pytorch_utils import ( convert_pytorch_state_dict_to_flax, load_flax_weights_in_pytorch_model, @@ -516,6 +524,32 @@ class FlaxModelTesterMixin: [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length], ) + def test_load_with_mismatched_shapes(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + if model_class not in get_values(FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING): + continue + + with self.subTest(msg=f"Testing {model_class}"): + with tempfile.TemporaryDirectory() as tmp_dir: + model = model_class(config) + model.save_pretrained(tmp_dir) + + 
# Fails when we don't set ignore_mismatched_sizes=True + with self.assertRaises(ValueError): + new_model = FlaxAutoModelForSequenceClassification.from_pretrained(tmp_dir, num_labels=42) + + logger = logging.get_logger("transformers.modeling_flax_utils") + with CaptureLogger(logger) as cl: + new_model = FlaxAutoModelForSequenceClassification.from_pretrained( + tmp_dir, num_labels=42, ignore_mismatched_sizes=True + ) + self.assertIn("the shapes did not match", cl.out) + + logits = new_model(**inputs_dict)["logits"] + self.assertEqual(logits.shape[1], 42) + @require_flax @is_staging_test diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py index 3c907c7470..2b7b2c9143 100644 --- a/tests/test_modeling_tf_common.py +++ b/tests/test_modeling_tf_common.py @@ -32,6 +32,7 @@ from transformers.testing_utils import ( ENDPOINT_STAGING, PASS, USER, + CaptureLogger, _tf_gpu_memory_limit, is_pt_tf_cross_test, is_staging_test, @@ -40,6 +41,7 @@ from transformers.testing_utils import ( slow, tooslow, ) +from transformers.utils import logging if is_tf_available(): @@ -57,6 +59,7 @@ if is_tf_available(): TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, BertConfig, + TFAutoModelForSequenceClassification, TFBertModel, TFSharedEmbeddings, tf_top_k_top_p_filtering, @@ -1308,6 +1311,34 @@ class TFModelTesterMixin: attn_weights = out[attn_name] if attn_name == attention_names[0] else out[attn_name][-1] self.assertEqual(sum([tf.reduce_sum(w).numpy() for w in attn_weights]), 0.0) + def test_load_with_mismatched_shapes(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + if model_class not in get_values(TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING): + continue + + with self.subTest(msg=f"Testing {model_class}"): + with tempfile.TemporaryDirectory() as tmp_dir: + model = model_class(config) + inputs = self._prepare_for_class(inputs_dict, 
model_class) + _ = model(**inputs) + model.save_pretrained(tmp_dir) + + # Fails when we don't set ignore_mismatched_sizes=True + with self.assertRaises(ValueError): + new_model = TFAutoModelForSequenceClassification.from_pretrained(tmp_dir, num_labels=42) + + logger = logging.get_logger("transformers.modeling_tf_utils") + with CaptureLogger(logger) as cl: + new_model = TFAutoModelForSequenceClassification.from_pretrained( + tmp_dir, num_labels=42, ignore_mismatched_sizes=True + ) + self.assertIn("the shapes did not match", cl.out) + + logits = new_model(**inputs).logits + self.assertEqual(logits.shape[1], 42) + def _generate_random_bad_tokens(self, num_bad_tokens, model): # special tokens cannot be bad tokens special_tokens = []