Fix from_pretrained with corrupted state_dict (#12939)
* Fix from_pretrained with corrupted state_dict * Adapt test * Use better checkpoint * Style * Clean up
This commit is contained in:
@@ -1409,6 +1409,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
add_prefix = has_prefix_module and not expects_prefix_module
|
add_prefix = has_prefix_module and not expects_prefix_module
|
||||||
|
|
||||||
if remove_prefix:
|
if remove_prefix:
|
||||||
|
expected_keys_not_prefixed = [s for s in expected_keys if not s.startswith(prefix)]
|
||||||
expected_keys = [".".join(s.split(".")[1:]) if s.startswith(prefix) else s for s in expected_keys]
|
expected_keys = [".".join(s.split(".")[1:]) if s.startswith(prefix) else s for s in expected_keys]
|
||||||
elif add_prefix:
|
elif add_prefix:
|
||||||
expected_keys = [".".join([prefix, s]) for s in expected_keys]
|
expected_keys = [".".join([prefix, s]) for s in expected_keys]
|
||||||
@@ -1490,6 +1491,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
start_prefix = cls.base_model_prefix + "."
|
start_prefix = cls.base_model_prefix + "."
|
||||||
if hasattr(model, cls.base_model_prefix) and not has_prefix_module:
|
if hasattr(model, cls.base_model_prefix) and not has_prefix_module:
|
||||||
model_to_load = getattr(model, cls.base_model_prefix)
|
model_to_load = getattr(model, cls.base_model_prefix)
|
||||||
|
if any(key in expected_keys_not_prefixed for key in loaded_keys):
|
||||||
|
raise ValueError(
|
||||||
|
"The state dictionary of the model you are training to load is corrupted. Are you sure it was "
|
||||||
|
"properly saved?"
|
||||||
|
)
|
||||||
|
|
||||||
load(model_to_load, prefix=start_prefix)
|
load(model_to_load, prefix=start_prefix)
|
||||||
|
|
||||||
|
|||||||
@@ -49,7 +49,7 @@ class BenchmarkTest(unittest.TestCase):
|
|||||||
self.check_results_dict_not_empty(results.memory_inference_result)
|
self.check_results_dict_not_empty(results.memory_inference_result)
|
||||||
|
|
||||||
def test_inference_no_configs_only_pretrain(self):
|
def test_inference_no_configs_only_pretrain(self):
|
||||||
MODEL_ID = "sshleifer/tiny-distilbert-base-uncased-finetuned-sst-2-english"
|
MODEL_ID = "sgugger/tiny-distilbert-classification"
|
||||||
benchmark_args = PyTorchBenchmarkArguments(
|
benchmark_args = PyTorchBenchmarkArguments(
|
||||||
models=[MODEL_ID],
|
models=[MODEL_ID],
|
||||||
training=False,
|
training=False,
|
||||||
|
|||||||
@@ -52,7 +52,7 @@ class TFBenchmarkTest(unittest.TestCase):
|
|||||||
self.check_results_dict_not_empty(results.memory_inference_result)
|
self.check_results_dict_not_empty(results.memory_inference_result)
|
||||||
|
|
||||||
def test_inference_no_configs_only_pretrain(self):
|
def test_inference_no_configs_only_pretrain(self):
|
||||||
MODEL_ID = "sshleifer/tiny-distilbert-base-uncased-finetuned-sst-2-english"
|
MODEL_ID = "sgugger/tiny-distilbert-classification"
|
||||||
benchmark_args = TensorFlowBenchmarkArguments(
|
benchmark_args = TensorFlowBenchmarkArguments(
|
||||||
models=[MODEL_ID],
|
models=[MODEL_ID],
|
||||||
training=False,
|
training=False,
|
||||||
|
|||||||
@@ -22,9 +22,7 @@ from .test_pipelines_common import CustomInputPipelineCommonMixin
|
|||||||
|
|
||||||
class ZeroShotClassificationPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase):
|
class ZeroShotClassificationPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase):
|
||||||
pipeline_task = "zero-shot-classification"
|
pipeline_task = "zero-shot-classification"
|
||||||
small_models = [
|
small_models = ["sgugger/tiny-distilbert-classification"] # Models tested without the @slow decorator
|
||||||
"sshleifer/tiny-distilbert-base-uncased-finetuned-sst-2-english"
|
|
||||||
] # Models tested without the @slow decorator
|
|
||||||
large_models = ["roberta-large-mnli"] # Models tested with the @slow decorator
|
large_models = ["roberta-large-mnli"] # Models tested with the @slow decorator
|
||||||
valid_inputs = [
|
valid_inputs = [
|
||||||
{"sequences": "Who are you voting for in 2020?", "candidate_labels": "politics"},
|
{"sequences": "Who are you voting for in 2020?", "candidate_labels": "politics"},
|
||||||
|
|||||||
Reference in New Issue
Block a user