[Benchmark] Memory benchmark utils (#4198)
* improve memory benchmarking * correct typo * fix current memory * check torch memory allocated * better pytorch function * add total cached gpu memory * add total gpu required * improve torch gpu usage * update memory usage * finalize memory tracing * save intermediate benchmark class * fix conflict * improve benchmark * improve benchmark * finalize * make style * improve benchmarking * correct typo * make train function more flexible * fix csv save * better repr of bytes * better print * fix __repr__ bug * finish plot script * rename plot file * delete csv and small improvements * fix in plot * fix in plot * correct usage of timeit * remove redundant line * remove redundant line * fix bug * add hf parser tests * add versioning and platform info * make style * add gpu information * ensure backward compatibility * finish adding all tests * Update src/transformers/benchmark/benchmark_args.py Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * Update src/transformers/benchmark/benchmark_args_utils.py Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * delete csv files * fix isort ordering * add out of memory handling * add better train memory handling Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
committed by
GitHub
parent
ec4cdfdd05
commit
96f57c9ccb
@@ -4,7 +4,7 @@ import sys
|
||||
from argparse import ArgumentParser
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Iterable, NewType, Tuple, Union
|
||||
from typing import Any, Iterable, List, NewType, Tuple, Union
|
||||
|
||||
|
||||
DataClass = NewType("DataClass", Any)
|
||||
@@ -52,9 +52,13 @@ class HfArgumentParser(ArgumentParser):
|
||||
"We will add compatibility when Python 3.9 is released."
|
||||
)
|
||||
typestring = str(field.type)
|
||||
for x in (int, float, str):
|
||||
if typestring == f"typing.Union[{x.__name__}, NoneType]":
|
||||
field.type = x
|
||||
for prim_type in (int, float, str):
|
||||
for collection in (List,):
|
||||
if typestring == f"typing.Union[{collection[prim_type]}, NoneType]":
|
||||
field.type = collection[prim_type]
|
||||
if typestring == f"typing.Union[{prim_type.__name__}, NoneType]":
|
||||
field.type = prim_type
|
||||
|
||||
if isinstance(field.type, type) and issubclass(field.type, Enum):
|
||||
kwargs["choices"] = list(field.type)
|
||||
kwargs["type"] = field.type
|
||||
@@ -65,6 +69,14 @@ class HfArgumentParser(ArgumentParser):
|
||||
if field.default is True:
|
||||
field_name = f"--no-{field.name}"
|
||||
kwargs["dest"] = field.name
|
||||
elif hasattr(field.type, "__origin__") and issubclass(field.type.__origin__, List):
|
||||
kwargs["nargs"] = "+"
|
||||
kwargs["type"] = field.type.__args__[0]
|
||||
assert all(
|
||||
x == kwargs["type"] for x in field.type.__args__
|
||||
), "{} cannot be a List of mixed types".format(field.name)
|
||||
if field.default_factory is not dataclasses.MISSING:
|
||||
kwargs["default"] = field.default_factory()
|
||||
else:
|
||||
kwargs["type"] = field.type
|
||||
if field.default is not dataclasses.MISSING:
|
||||
|
||||
Reference in New Issue
Block a user