Fix from_pretrained with corrupted state_dict (#12939)
* Fix from_pretrained with corrupted state_dict * Adapt test * Use better checkpoint * Style * Clean up
This commit is contained in:
@@ -49,7 +49,7 @@ class BenchmarkTest(unittest.TestCase):
|
||||
self.check_results_dict_not_empty(results.memory_inference_result)
|
||||
|
||||
def test_inference_no_configs_only_pretrain(self):
|
||||
MODEL_ID = "sshleifer/tiny-distilbert-base-uncased-finetuned-sst-2-english"
|
||||
MODEL_ID = "sgugger/tiny-distilbert-classification"
|
||||
benchmark_args = PyTorchBenchmarkArguments(
|
||||
models=[MODEL_ID],
|
||||
training=False,
|
||||
|
||||
@@ -52,7 +52,7 @@ class TFBenchmarkTest(unittest.TestCase):
|
||||
self.check_results_dict_not_empty(results.memory_inference_result)
|
||||
|
||||
def test_inference_no_configs_only_pretrain(self):
|
||||
MODEL_ID = "sshleifer/tiny-distilbert-base-uncased-finetuned-sst-2-english"
|
||||
MODEL_ID = "sgugger/tiny-distilbert-classification"
|
||||
benchmark_args = TensorFlowBenchmarkArguments(
|
||||
models=[MODEL_ID],
|
||||
training=False,
|
||||
|
||||
@@ -22,9 +22,7 @@ from .test_pipelines_common import CustomInputPipelineCommonMixin
|
||||
|
||||
class ZeroShotClassificationPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase):
|
||||
pipeline_task = "zero-shot-classification"
|
||||
small_models = [
|
||||
"sshleifer/tiny-distilbert-base-uncased-finetuned-sst-2-english"
|
||||
] # Models tested without the @slow decorator
|
||||
small_models = ["sgugger/tiny-distilbert-classification"] # Models tested without the @slow decorator
|
||||
large_models = ["roberta-large-mnli"] # Models tested with the @slow decorator
|
||||
valid_inputs = [
|
||||
{"sequences": "Who are you voting for in 2020?", "candidate_labels": "politics"},
|
||||
|
||||
Reference in New Issue
Block a user