[Benchmark] Memory benchmark utils (#4198)
* improve memory benchmarking * correct typo * fix current memory * check torch memory allocated * better pytorch function * add total cached gpu memory * add total gpu required * improve torch gpu usage * update memory usage * finalize memory tracing * save intermediate benchmark class * fix conflict * improve benchmark * improve benchmark * finalize * make style * improve benchmarking * correct typo * make train function more flexible * fix csv save * better repr of bytes * better print * fix __repr__ bug * finish plot script * rename plot file * delete csv and small improvements * fix in plot * fix in plot * correct usage of timeit * remove redundant line * remove redundant line * fix bug * add hf parser tests * add versioning and platform info * make style * add gpu information * ensure backward compatibility * finish adding all tests * Update src/transformers/benchmark/benchmark_args.py Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * Update src/transformers/benchmark/benchmark_args_utils.py Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * delete csv files * fix isort ordering * add out of memory handling * add better train memory handling Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
committed by
GitHub
parent
ec4cdfdd05
commit
96f57c9ccb
@@ -3,11 +3,15 @@ import unittest
|
||||
from argparse import Namespace
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
from typing import List, Optional
|
||||
|
||||
from transformers import HfArgumentParser, TrainingArguments
|
||||
|
||||
|
||||
def list_field(default=None, metadata=None):
|
||||
return field(default_factory=lambda: default, metadata=metadata)
|
||||
|
||||
|
||||
@dataclass
|
||||
class BasicExample:
|
||||
foo: int
|
||||
@@ -43,6 +47,16 @@ class OptionalExample:
|
||||
foo: Optional[int] = None
|
||||
bar: Optional[float] = field(default=None, metadata={"help": "help message"})
|
||||
baz: Optional[str] = None
|
||||
ces: Optional[List[str]] = list_field(default=[])
|
||||
des: Optional[List[int]] = list_field(default=[])
|
||||
|
||||
|
||||
@dataclass
|
||||
class ListExample:
|
||||
foo_int: List[int] = list_field(default=[])
|
||||
bar_int: List[int] = list_field(default=[1, 2, 3])
|
||||
foo_str: List[str] = list_field(default=["Hallo", "Bonjour", "Hello"])
|
||||
foo_float: List[float] = list_field(default=[0.1, 0.2, 0.3])
|
||||
|
||||
|
||||
class HfArgumentParserTest(unittest.TestCase):
|
||||
@@ -101,6 +115,26 @@ class HfArgumentParserTest(unittest.TestCase):
|
||||
args = parser.parse_args(["--foo", "titi"])
|
||||
self.assertEqual(args.foo, BasicEnum.titi)
|
||||
|
||||
def test_with_list(self):
|
||||
parser = HfArgumentParser(ListExample)
|
||||
|
||||
expected = argparse.ArgumentParser()
|
||||
expected.add_argument("--foo_int", nargs="+", default=[], type=int)
|
||||
expected.add_argument("--bar_int", nargs="+", default=[1, 2, 3], type=int)
|
||||
expected.add_argument("--foo_str", nargs="+", default=["Hallo", "Bonjour", "Hello"], type=str)
|
||||
expected.add_argument("--foo_float", nargs="+", default=[0.1, 0.2, 0.3], type=float)
|
||||
|
||||
self.argparsersEqual(parser, expected)
|
||||
|
||||
args = parser.parse_args([])
|
||||
self.assertEqual(
|
||||
args,
|
||||
Namespace(foo_int=[], bar_int=[1, 2, 3], foo_str=["Hallo", "Bonjour", "Hello"], foo_float=[0.1, 0.2, 0.3]),
|
||||
)
|
||||
|
||||
args = parser.parse_args("--foo_int 1 --bar_int 2 3 --foo_str a b c --foo_float 0.1 0.7".split())
|
||||
self.assertEqual(args, Namespace(foo_int=[1], bar_int=[2, 3], foo_str=["a", "b", "c"], foo_float=[0.1, 0.7]))
|
||||
|
||||
def test_with_optional(self):
|
||||
parser = HfArgumentParser(OptionalExample)
|
||||
|
||||
@@ -108,13 +142,15 @@ class HfArgumentParserTest(unittest.TestCase):
|
||||
expected.add_argument("--foo", default=None, type=int)
|
||||
expected.add_argument("--bar", default=None, type=float, help="help message")
|
||||
expected.add_argument("--baz", default=None, type=str)
|
||||
expected.add_argument("--ces", nargs="+", default=[], type=str)
|
||||
expected.add_argument("--des", nargs="+", default=[], type=int)
|
||||
self.argparsersEqual(parser, expected)
|
||||
|
||||
args = parser.parse_args([])
|
||||
self.assertEqual(args, Namespace(foo=None, bar=None, baz=None))
|
||||
self.assertEqual(args, Namespace(foo=None, bar=None, baz=None, ces=[], des=[]))
|
||||
|
||||
args = parser.parse_args("--foo 12 --bar 3.14 --baz 42".split())
|
||||
self.assertEqual(args, Namespace(foo=12, bar=3.14, baz="42"))
|
||||
args = parser.parse_args("--foo 12 --bar 3.14 --baz 42 --ces a b c --des 1 2 3".split())
|
||||
self.assertEqual(args, Namespace(foo=12, bar=3.14, baz="42", ces=["a", "b", "c"], des=[1, 2, 3]))
|
||||
|
||||
def test_integration_training_args(self):
|
||||
parser = HfArgumentParser(TrainingArguments)
|
||||
|
||||
Reference in New Issue
Block a user