[Benchmark] Memory benchmark utils (#4198)

* improve memory benchmarking

* correct typo

* fix current memory

* check torch memory allocated

* better pytorch function

* add total cached gpu memory

* add total gpu required

* improve torch gpu usage

* update memory usage

* finalize memory tracing

* save intermediate benchmark class

* fix conflict

* improve benchmark

* improve benchmark

* finalize

* make style

* improve benchmarking

* correct typo

* make train function more flexible

* fix csv save

* better repr of bytes

* better print

* fix __repr__ bug

* finish plot script

* rename plot file

* delete csv and small improvements

* fix in plot

* fix in plot

* correct usage of timeit

* remove redundant line

* remove redundant line

* fix bug

* add hf parser tests

* add versioning and platform info

* make style

* add gpu information

* ensure backward compatibility

* finish adding all tests

* Update src/transformers/benchmark/benchmark_args.py

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Update src/transformers/benchmark/benchmark_args_utils.py

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* delete csv files

* fix isort ordering

* add out of memory handling

* add better train memory handling

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
Patrick von Platen
2020-05-27 23:22:16 +02:00
committed by GitHub
parent ec4cdfdd05
commit 96f57c9ccb
14 changed files with 934 additions and 744 deletions

View File

@@ -4,7 +4,7 @@ import sys
from argparse import ArgumentParser
from enum import Enum
from pathlib import Path
from typing import Any, Iterable, NewType, Tuple, Union
from typing import Any, Iterable, List, NewType, Tuple, Union
DataClass = NewType("DataClass", Any)
@@ -52,9 +52,13 @@ class HfArgumentParser(ArgumentParser):
"We will add compatibility when Python 3.9 is released."
)
typestring = str(field.type)
for x in (int, float, str):
if typestring == f"typing.Union[{x.__name__}, NoneType]":
field.type = x
for prim_type in (int, float, str):
for collection in (List,):
if typestring == f"typing.Union[{collection[prim_type]}, NoneType]":
field.type = collection[prim_type]
if typestring == f"typing.Union[{prim_type.__name__}, NoneType]":
field.type = prim_type
if isinstance(field.type, type) and issubclass(field.type, Enum):
kwargs["choices"] = list(field.type)
kwargs["type"] = field.type
@@ -65,6 +69,14 @@ class HfArgumentParser(ArgumentParser):
if field.default is True:
field_name = f"--no-{field.name}"
kwargs["dest"] = field.name
elif hasattr(field.type, "__origin__") and issubclass(field.type.__origin__, List):
kwargs["nargs"] = "+"
kwargs["type"] = field.type.__args__[0]
assert all(
x == kwargs["type"] for x in field.type.__args__
), "{} cannot be a List of mixed types".format(field.name)
if field.default_factory is not dataclasses.MISSING:
kwargs["default"] = field.default_factory()
else:
kwargs["type"] = field.type
if field.default is not dataclasses.MISSING: