[Benchmark] Extend Benchmark to all model type extensions (#5241)
* add benchmark for all kinds of models * improved import * delete bogus files * make style
This commit is contained in:
committed by
GitHub
parent
7c41057d50
commit
9fe09cec76
@@ -38,6 +38,22 @@ class BenchmarkTest(unittest.TestCase):
|
||||
self.check_results_dict_not_empty(results.time_inference_result)
|
||||
self.check_results_dict_not_empty(results.memory_inference_result)
|
||||
|
||||
def test_inference_no_configs_only_pretrain(self):
|
||||
MODEL_ID = "sshleifer/tiny-distilbert-base-uncased-finetuned-sst-2-english"
|
||||
benchmark_args = PyTorchBenchmarkArguments(
|
||||
models=[MODEL_ID],
|
||||
training=False,
|
||||
no_inference=False,
|
||||
sequence_lengths=[8],
|
||||
batch_sizes=[1],
|
||||
no_multi_process=True,
|
||||
only_pretrain_model=True,
|
||||
)
|
||||
benchmark = PyTorchBenchmark(benchmark_args)
|
||||
results = benchmark.run()
|
||||
self.check_results_dict_not_empty(results.time_inference_result)
|
||||
self.check_results_dict_not_empty(results.memory_inference_result)
|
||||
|
||||
def test_inference_torchscript(self):
|
||||
MODEL_ID = "sshleifer/tiny-gpt2"
|
||||
benchmark_args = PyTorchBenchmarkArguments(
|
||||
|
||||
Reference in New Issue
Block a user