[Benchmark] Memory benchmark utils (#4198)
* improve memory benchmarking * correct typo * fix current memory * check torch memory allocated * better pytorch function * add total cached gpu memory * add total gpu required * improve torch gpu usage * update memory usage * finalize memory tracing * save intermediate benchmark class * fix conflict * improve benchmark * improve benchmark * finalize * make style * improve benchmarking * correct typo * make train function more flexible * fix csv save * better repr of bytes * better print * fix __repr__ bug * finish plot script * rename plot file * delete csv and small improvements * fix in plot * fix in plot * correct usage of timeit * remove redundant line * remove redundant line * fix bug * add hf parser tests * add versioning and platform info * make style * add gpu information * ensure backward compatibility * finish adding all tests * Update src/transformers/benchmark/benchmark_args.py Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * Update src/transformers/benchmark/benchmark_args_utils.py Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * delete csv files * fix isort ordering * add out of memory handling * add better train memory handling Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
committed by
GitHub
parent
ec4cdfdd05
commit
96f57c9ccb
90
tests/test_benchmark.py
Normal file
90
tests/test_benchmark.py
Normal file
@@ -0,0 +1,90 @@
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from transformers import GPT2Config, is_torch_available
|
||||
|
||||
from .utils import require_torch
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
from transformers import (
|
||||
PyTorchBenchmarkArguments,
|
||||
PyTorchBenchmark,
|
||||
)
|
||||
|
||||
|
||||
@require_torch
|
||||
class BenchmarkTest(unittest.TestCase):
|
||||
def check_results_dict_not_empty(self, results):
|
||||
for model_result in results.values():
|
||||
for batch_size, sequence_length in zip(model_result["bs"], model_result["ss"]):
|
||||
result = model_result["result"][batch_size][sequence_length]
|
||||
self.assertIsNotNone(result)
|
||||
|
||||
def test_inference_no_configs(self):
|
||||
MODEL_ID = "sshleifer/tiny-gpt2"
|
||||
benchmark_args = PyTorchBenchmarkArguments(
|
||||
models=[MODEL_ID], training=False, no_inference=False, sequence_lengths=[8], batch_sizes=[1]
|
||||
)
|
||||
benchmark = PyTorchBenchmark(benchmark_args)
|
||||
results = benchmark.run()
|
||||
self.check_results_dict_not_empty(results.time_inference_result)
|
||||
self.check_results_dict_not_empty(results.memory_inference_result)
|
||||
|
||||
def test_train_no_configs(self):
|
||||
MODEL_ID = "sshleifer/tiny-gpt2"
|
||||
benchmark_args = PyTorchBenchmarkArguments(
|
||||
models=[MODEL_ID], training=True, no_inference=True, sequence_lengths=[8], batch_sizes=[1]
|
||||
)
|
||||
benchmark = PyTorchBenchmark(benchmark_args)
|
||||
results = benchmark.run()
|
||||
self.check_results_dict_not_empty(results.time_train_result)
|
||||
self.check_results_dict_not_empty(results.memory_train_result)
|
||||
|
||||
def test_inference_with_configs(self):
|
||||
MODEL_ID = "sshleifer/tiny-gpt2"
|
||||
config = GPT2Config.from_pretrained(MODEL_ID)
|
||||
benchmark_args = PyTorchBenchmarkArguments(
|
||||
models=[MODEL_ID], training=False, no_inference=False, sequence_lengths=[8], batch_sizes=[1]
|
||||
)
|
||||
benchmark = PyTorchBenchmark(benchmark_args, configs=[config])
|
||||
results = benchmark.run()
|
||||
self.check_results_dict_not_empty(results.time_inference_result)
|
||||
self.check_results_dict_not_empty(results.memory_inference_result)
|
||||
|
||||
def test_train_with_configs(self):
|
||||
MODEL_ID = "sshleifer/tiny-gpt2"
|
||||
config = GPT2Config.from_pretrained(MODEL_ID)
|
||||
benchmark_args = PyTorchBenchmarkArguments(
|
||||
models=[MODEL_ID], training=True, no_inference=True, sequence_lengths=[8], batch_sizes=[1]
|
||||
)
|
||||
benchmark = PyTorchBenchmark(benchmark_args, configs=[config])
|
||||
results = benchmark.run()
|
||||
self.check_results_dict_not_empty(results.time_train_result)
|
||||
self.check_results_dict_not_empty(results.memory_train_result)
|
||||
|
||||
def test_save_csv_files(self):
|
||||
MODEL_ID = "sshleifer/tiny-gpt2"
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
benchmark_args = PyTorchBenchmarkArguments(
|
||||
models=[MODEL_ID],
|
||||
training=True,
|
||||
no_inference=False,
|
||||
save_to_csv=True,
|
||||
sequence_lengths=[8],
|
||||
batch_sizes=[1],
|
||||
inference_time_csv_file=os.path.join(tmp_dir, "inf_time.csv"),
|
||||
train_memory_csv_file=os.path.join(tmp_dir, "train_mem.csv"),
|
||||
inference_memory_csv_file=os.path.join(tmp_dir, "inf_mem.csv"),
|
||||
train_time_csv_file=os.path.join(tmp_dir, "train_time.csv"),
|
||||
env_info_csv_file=os.path.join(tmp_dir, "env.csv"),
|
||||
)
|
||||
benchmark = PyTorchBenchmark(benchmark_args)
|
||||
benchmark.run()
|
||||
self.assertTrue(Path(os.path.join(tmp_dir, "inf_time.csv")).exists())
|
||||
self.assertTrue(Path(os.path.join(tmp_dir, "train_time.csv")).exists())
|
||||
self.assertTrue(Path(os.path.join(tmp_dir, "inf_mem.csv")).exists())
|
||||
self.assertTrue(Path(os.path.join(tmp_dir, "train_mem.csv")).exists())
|
||||
self.assertTrue(Path(os.path.join(tmp_dir, "env.csv")).exists())
|
||||
@@ -3,11 +3,15 @@ import unittest
|
||||
from argparse import Namespace
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
from typing import List, Optional
|
||||
|
||||
from transformers import HfArgumentParser, TrainingArguments
|
||||
|
||||
|
||||
def list_field(default=None, metadata=None):
|
||||
return field(default_factory=lambda: default, metadata=metadata)
|
||||
|
||||
|
||||
@dataclass
|
||||
class BasicExample:
|
||||
foo: int
|
||||
@@ -43,6 +47,16 @@ class OptionalExample:
|
||||
foo: Optional[int] = None
|
||||
bar: Optional[float] = field(default=None, metadata={"help": "help message"})
|
||||
baz: Optional[str] = None
|
||||
ces: Optional[List[str]] = list_field(default=[])
|
||||
des: Optional[List[int]] = list_field(default=[])
|
||||
|
||||
|
||||
@dataclass
|
||||
class ListExample:
|
||||
foo_int: List[int] = list_field(default=[])
|
||||
bar_int: List[int] = list_field(default=[1, 2, 3])
|
||||
foo_str: List[str] = list_field(default=["Hallo", "Bonjour", "Hello"])
|
||||
foo_float: List[float] = list_field(default=[0.1, 0.2, 0.3])
|
||||
|
||||
|
||||
class HfArgumentParserTest(unittest.TestCase):
|
||||
@@ -101,6 +115,26 @@ class HfArgumentParserTest(unittest.TestCase):
|
||||
args = parser.parse_args(["--foo", "titi"])
|
||||
self.assertEqual(args.foo, BasicEnum.titi)
|
||||
|
||||
def test_with_list(self):
|
||||
parser = HfArgumentParser(ListExample)
|
||||
|
||||
expected = argparse.ArgumentParser()
|
||||
expected.add_argument("--foo_int", nargs="+", default=[], type=int)
|
||||
expected.add_argument("--bar_int", nargs="+", default=[1, 2, 3], type=int)
|
||||
expected.add_argument("--foo_str", nargs="+", default=["Hallo", "Bonjour", "Hello"], type=str)
|
||||
expected.add_argument("--foo_float", nargs="+", default=[0.1, 0.2, 0.3], type=float)
|
||||
|
||||
self.argparsersEqual(parser, expected)
|
||||
|
||||
args = parser.parse_args([])
|
||||
self.assertEqual(
|
||||
args,
|
||||
Namespace(foo_int=[], bar_int=[1, 2, 3], foo_str=["Hallo", "Bonjour", "Hello"], foo_float=[0.1, 0.2, 0.3]),
|
||||
)
|
||||
|
||||
args = parser.parse_args("--foo_int 1 --bar_int 2 3 --foo_str a b c --foo_float 0.1 0.7".split())
|
||||
self.assertEqual(args, Namespace(foo_int=[1], bar_int=[2, 3], foo_str=["a", "b", "c"], foo_float=[0.1, 0.7]))
|
||||
|
||||
def test_with_optional(self):
|
||||
parser = HfArgumentParser(OptionalExample)
|
||||
|
||||
@@ -108,13 +142,15 @@ class HfArgumentParserTest(unittest.TestCase):
|
||||
expected.add_argument("--foo", default=None, type=int)
|
||||
expected.add_argument("--bar", default=None, type=float, help="help message")
|
||||
expected.add_argument("--baz", default=None, type=str)
|
||||
expected.add_argument("--ces", nargs="+", default=[], type=str)
|
||||
expected.add_argument("--des", nargs="+", default=[], type=int)
|
||||
self.argparsersEqual(parser, expected)
|
||||
|
||||
args = parser.parse_args([])
|
||||
self.assertEqual(args, Namespace(foo=None, bar=None, baz=None))
|
||||
self.assertEqual(args, Namespace(foo=None, bar=None, baz=None, ces=[], des=[]))
|
||||
|
||||
args = parser.parse_args("--foo 12 --bar 3.14 --baz 42".split())
|
||||
self.assertEqual(args, Namespace(foo=12, bar=3.14, baz="42"))
|
||||
args = parser.parse_args("--foo 12 --bar 3.14 --baz 42 --ces a b c --des 1 2 3".split())
|
||||
self.assertEqual(args, Namespace(foo=12, bar=3.14, baz="42", ces=["a", "b", "c"], des=[1, 2, 3]))
|
||||
|
||||
def test_integration_training_args(self):
|
||||
parser = HfArgumentParser(TrainingArguments)
|
||||
|
||||
Reference in New Issue
Block a user