From 234cfefbb083d2614a55f6093b0badfb2efc3b45 Mon Sep 17 00:00:00 2001 From: "Li-Huai (Allan) Lin" Date: Fri, 22 Oct 2021 00:31:29 +0800 Subject: [PATCH] Fix ignore_mismatched_sizes (#14085) * Fix * Style * Name * Fix tests * Style * Remove embed sizes checking * Disable some tests * Fix * Apply suggestion --- src/transformers/modeling_utils.py | 6 +++--- tests/test_modeling_canine.py | 1 + tests/test_modeling_common.py | 22 +++++++++++++++++++--- tests/test_modeling_flax_big_bird.py | 1 + tests/test_modeling_flax_common.py | 17 +++++++++++++++++ tests/test_modeling_layoutlmv2.py | 1 + tests/test_modeling_tf_common.py | 20 ++++++++++++++++++++ tests/test_modeling_tf_transfo_xl.py | 1 + tests/test_modeling_transfo_xl.py | 1 + 9 files changed, 64 insertions(+), 6 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 14e8546a3c..9d825b5b20 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1512,10 +1512,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix if ignore_mismatched_sizes: for checkpoint_key in loaded_keys: model_key = checkpoint_key - if remove_prefix and checkpoint_key.startswith(prefix): - model_key = ".".join(checkpoint_key.split(".")[1:]) - elif add_prefix: + if remove_prefix: model_key = f"{prefix}.{checkpoint_key}" + elif add_prefix: + model_key = ".".join(checkpoint_key.split(".")[1:]) if ( model_key in model_state_dict diff --git a/tests/test_modeling_canine.py b/tests/test_modeling_canine.py index c6a30e855b..888e1d33e6 100644 --- a/tests/test_modeling_canine.py +++ b/tests/test_modeling_canine.py @@ -220,6 +220,7 @@ class CanineModelTest(ModelTesterMixin, unittest.TestCase): ) test_torchscript = False + test_mismatched_shapes = False test_resize_embeddings = False test_pruning = False diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index f946d59017..889c2e830f 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -98,6 +98,7 @@ class ModelTesterMixin: test_resize_embeddings = True test_resize_position_embeddings = False test_head_masking = True + test_mismatched_shapes = True test_missing_keys = True test_model_parallel = False is_encoder_decoder = False @@ -1638,6 +1639,8 @@ class ModelTesterMixin: loss.backward() def test_load_with_mismatched_shapes(self): + if not self.test_mismatched_shapes: + return config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() for model_class in self.all_model_classes: @@ -1650,22 +1653,35 @@ class ModelTesterMixin: model.save_pretrained(tmp_dir) # Fails when we don't set ignore_mismatched_sizes=True - with self.assertRaises(RuntimeError) as e: - print(type(e)) + with self.assertRaises(RuntimeError): new_model = AutoModelForSequenceClassification.from_pretrained(tmp_dir, num_labels=42) + with self.assertRaises(RuntimeError): + new_model_without_prefix = AutoModel.from_pretrained(tmp_dir, vocab_size=10) logger = logging.get_logger("transformers.modeling_utils") + with CaptureLogger(logger) as cl: new_model = AutoModelForSequenceClassification.from_pretrained( tmp_dir, num_labels=42, ignore_mismatched_sizes=True ) self.assertIn("the shapes did not match", cl.out) - new_model.to(torch_device) inputs = self._prepare_for_class(inputs_dict, model_class) logits = new_model(**inputs).logits self.assertEqual(logits.shape[1], 42) + with CaptureLogger(logger) as cl: + new_model_without_prefix = AutoModel.from_pretrained( + tmp_dir, vocab_size=10, ignore_mismatched_sizes=True + ) + self.assertIn("the shapes did not match", cl.out) + input_ids = ids_tensor((2, 8), 10) + new_model_without_prefix.to(torch_device) + if self.is_encoder_decoder: + new_model_without_prefix(input_ids, decoder_input_ids=input_ids) + else: + new_model_without_prefix(input_ids) + global_rng = random.Random() diff --git a/tests/test_modeling_flax_big_bird.py b/tests/test_modeling_flax_big_bird.py index c8afce2685..8af5949a1a 100644 --- a/tests/test_modeling_flax_big_bird.py +++ b/tests/test_modeling_flax_big_bird.py @@ -149,6 +149,7 @@ class FlaxBigBirdModelTest(FlaxModelTesterMixin, unittest.TestCase): ) test_attn_probs = False + test_mismatched_shapes = False def setUp(self): self.model_tester = FlaxBigBirdModelTester(self) diff --git a/tests/test_modeling_flax_common.py b/tests/test_modeling_flax_common.py index 612cb5d2aa..4e5acbfa65 100644 --- a/tests/test_modeling_flax_common.py +++ b/tests/test_modeling_flax_common.py @@ -49,6 +49,7 @@ if is_flax_available(): FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING, FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, FLAX_MODEL_MAPPING, + FlaxAutoModel, FlaxAutoModelForSequenceClassification, FlaxBertModel, ) @@ -116,6 +117,7 @@ def random_attention_mask(shape, rng=None): class FlaxModelTesterMixin: model_tester = None all_model_classes = () + test_mismatched_shapes = True is_encoder_decoder = False def _prepare_for_class(self, inputs_dict, model_class): @@ -579,6 +581,8 @@ class FlaxModelTesterMixin: ) def test_load_with_mismatched_shapes(self): + if not self.test_mismatched_shapes: + return config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() for model_class in self.all_model_classes: @@ -593,6 +597,8 @@ class FlaxModelTesterMixin: # Fails when we don't set ignore_mismatched_sizes=True with self.assertRaises(ValueError): new_model = FlaxAutoModelForSequenceClassification.from_pretrained(tmp_dir, num_labels=42) + with self.assertRaises(ValueError): + new_model_without_prefix = FlaxAutoModel.from_pretrained(tmp_dir, vocab_size=10) logger = logging.get_logger("transformers.modeling_flax_utils") with CaptureLogger(logger) as cl: @@ -604,6 +610,17 @@ class FlaxModelTesterMixin: logits = new_model(**inputs_dict)["logits"] self.assertEqual(logits.shape[1], 42) + with CaptureLogger(logger) as cl: + new_model_without_prefix = FlaxAutoModel.from_pretrained( + tmp_dir, vocab_size=10, ignore_mismatched_sizes=True + ) + self.assertIn("the shapes did not match", cl.out) + input_ids = ids_tensor((2, 8), 10) + if self.is_encoder_decoder: + new_model_without_prefix(input_ids, decoder_input_ids=input_ids) + else: + new_model_without_prefix(input_ids) + @require_flax @is_staging_test diff --git a/tests/test_modeling_layoutlmv2.py b/tests/test_modeling_layoutlmv2.py index 4ce25d5bd1..cbcb6f5139 100644 --- a/tests/test_modeling_layoutlmv2.py +++ b/tests/test_modeling_layoutlmv2.py @@ -260,6 +260,7 @@ class LayoutLMv2ModelTest(ModelTesterMixin, unittest.TestCase): test_pruning = False test_torchscript = False + test_mismatched_shapes = False all_model_classes = ( ( diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py index 49cd6e4fb1..8a6897b038 100644 --- a/tests/test_modeling_tf_common.py +++ b/tests/test_modeling_tf_common.py @@ -59,6 +59,7 @@ if is_tf_available(): TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, BertConfig, + TFAutoModel, TFAutoModelForSequenceClassification, TFBertModel, TFSharedEmbeddings, @@ -104,6 +105,7 @@ class TFModelTesterMixin: model_tester = None all_model_classes = () all_generative_model_classes = () + test_mismatched_shapes = True test_resize_embeddings = True test_head_masking = True is_encoder_decoder = False @@ -1312,6 +1314,8 @@ class TFModelTesterMixin: self.assertEqual(sum([tf.reduce_sum(w).numpy() for w in attn_weights]), 0.0) def test_load_with_mismatched_shapes(self): + if not self.test_mismatched_shapes: + return config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() for model_class in self.all_model_classes: @@ -1328,6 +1332,8 @@ class TFModelTesterMixin: # Fails when we don't set ignore_mismatched_sizes=True with self.assertRaises(ValueError): new_model = TFAutoModelForSequenceClassification.from_pretrained(tmp_dir, num_labels=42) + with self.assertRaises(ValueError): + new_model_without_prefix = TFAutoModel.from_pretrained(tmp_dir, vocab_size=10) logger = logging.get_logger("transformers.modeling_tf_utils") with CaptureLogger(logger) as cl: @@ -1339,6 +1345,20 @@ class TFModelTesterMixin: logits = new_model(**inputs).logits self.assertEqual(logits.shape[1], 42) + with CaptureLogger(logger) as cl: + new_model_without_prefix = TFAutoModel.from_pretrained( + tmp_dir, vocab_size=10, ignore_mismatched_sizes=True + ) + self.assertIn("the shapes did not match", cl.out) + + # Although Tf models always have a prefix pointing to `MainLayer`, + # we still add this "without prefix" test to keep a consistency between tf and pt tests. + input_ids = ids_tensor((2, 8), 10) + if self.is_encoder_decoder: + new_model_without_prefix(input_ids, decoder_input_ids=input_ids) + else: + new_model_without_prefix(input_ids) + def _generate_random_bad_tokens(self, num_bad_tokens, model): # special tokens cannot be bad tokens special_tokens = [] diff --git a/tests/test_modeling_tf_transfo_xl.py b/tests/test_modeling_tf_transfo_xl.py index da465f9d44..1b38c92585 100644 --- a/tests/test_modeling_tf_transfo_xl.py +++ b/tests/test_modeling_tf_transfo_xl.py @@ -165,6 +165,7 @@ class TFTransfoXLModelTest(TFModelTesterMixin, unittest.TestCase): test_resize_embeddings = False test_head_masking = False test_onnx = False + test_mismatched_shapes = False def setUp(self): self.model_tester = TFTransfoXLModelTester(self) diff --git a/tests/test_modeling_transfo_xl.py b/tests/test_modeling_transfo_xl.py index 51ec77d24d..4885e97329 100644 --- a/tests/test_modeling_transfo_xl.py +++ b/tests/test_modeling_transfo_xl.py @@ -180,6 +180,7 @@ class TransfoXLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestC test_pruning = False test_torchscript = False test_resize_embeddings = True + test_mismatched_shapes = False def check_cutoffs_and_n_token( self, copied_cutoffs, layer, model_embed, model, model_class, resized_value, vocab_size