From d4c834d2e0465ef268a1f00ee6666982754bcf2d Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Wed, 4 Aug 2021 11:48:39 +0200 Subject: [PATCH] Fix from_pretrained with corrupted state_dict (#12939) * Fix from_pretrained with corrupted state_dict * Adapt test * Use better checkpoint * Style * Clean up --- src/transformers/modeling_utils.py | 6 ++++++ tests/test_benchmark.py | 2 +- tests/test_benchmark_tf.py | 2 +- tests/test_pipelines_zero_shot.py | 4 +--- 4 files changed, 9 insertions(+), 5 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index b34637b021..5d4d6cfd0f 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1409,6 +1409,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix add_prefix = has_prefix_module and not expects_prefix_module if remove_prefix: + expected_keys_not_prefixed = [s for s in expected_keys if not s.startswith(prefix)] expected_keys = [".".join(s.split(".")[1:]) if s.startswith(prefix) else s for s in expected_keys] elif add_prefix: expected_keys = [".".join([prefix, s]) for s in expected_keys] @@ -1490,6 +1491,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix start_prefix = cls.base_model_prefix + "." if hasattr(model, cls.base_model_prefix) and not has_prefix_module: model_to_load = getattr(model, cls.base_model_prefix) + if any(key in expected_keys_not_prefixed for key in loaded_keys): + raise ValueError( + "The state dictionary of the model you are training to load is corrupted. Are you sure it was " + "properly saved?" + ) load(model_to_load, prefix=start_prefix) diff --git a/tests/test_benchmark.py b/tests/test_benchmark.py index 359efba8bb..e4444ec2c4 100644 --- a/tests/test_benchmark.py +++ b/tests/test_benchmark.py @@ -49,7 +49,7 @@ class BenchmarkTest(unittest.TestCase): self.check_results_dict_not_empty(results.memory_inference_result) def test_inference_no_configs_only_pretrain(self): - MODEL_ID = "sshleifer/tiny-distilbert-base-uncased-finetuned-sst-2-english" + MODEL_ID = "sgugger/tiny-distilbert-classification" benchmark_args = PyTorchBenchmarkArguments( models=[MODEL_ID], training=False, diff --git a/tests/test_benchmark_tf.py b/tests/test_benchmark_tf.py index 2bd72e09d0..2cea8e4c68 100644 --- a/tests/test_benchmark_tf.py +++ b/tests/test_benchmark_tf.py @@ -52,7 +52,7 @@ class TFBenchmarkTest(unittest.TestCase): self.check_results_dict_not_empty(results.memory_inference_result) def test_inference_no_configs_only_pretrain(self): - MODEL_ID = "sshleifer/tiny-distilbert-base-uncased-finetuned-sst-2-english" + MODEL_ID = "sgugger/tiny-distilbert-classification" benchmark_args = TensorFlowBenchmarkArguments( models=[MODEL_ID], training=False, diff --git a/tests/test_pipelines_zero_shot.py b/tests/test_pipelines_zero_shot.py index 20f2666c81..a436fbb407 100644 --- a/tests/test_pipelines_zero_shot.py +++ b/tests/test_pipelines_zero_shot.py @@ -22,9 +22,7 @@ from .test_pipelines_common import CustomInputPipelineCommonMixin class ZeroShotClassificationPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase): pipeline_task = "zero-shot-classification" - small_models = [ - "sshleifer/tiny-distilbert-base-uncased-finetuned-sst-2-english" - ] # Models tested without the @slow decorator + small_models = ["sgugger/tiny-distilbert-classification"] # Models tested without the @slow decorator large_models = ["roberta-large-mnli"] # Models tested with the @slow decorator valid_inputs = [ {"sequences": "Who are you voting for in 2020?", "candidate_labels": "politics"},