Enhance HfArgumentParser functionality and ease of use (#20323)

* Enhance HfArgumentParser

* Fix type hints for older python versions

* Fix and add tests (+formatting)

* Add changes

* doc-builder formatting

* Remove unused import "Call"
This commit is contained in:
Konstantin Dobler
2022-11-21 18:33:37 +01:00
committed by GitHub
parent 96783e53b4
commit 1e3f17b5ab
2 changed files with 204 additions and 23 deletions

View File

@@ -25,7 +25,15 @@ from typing import List, Optional
import yaml
from transformers import HfArgumentParser, TrainingArguments
from transformers.hf_argparser import string_to_bool
from transformers.hf_argparser import make_choice_type_function, string_to_bool
try:
# For Python versions <3.8, Literal is not in typing: https://peps.python.org/pep-0586/
from typing import Literal
except ImportError:
# For Python 3.7
from typing_extensions import Literal
def list_field(default=None, metadata=None):
@@ -58,6 +66,12 @@ class BasicEnum(Enum):
toto = "toto"
class MixedTypeEnum(Enum):
titi = "titi"
toto = "toto"
fourtytwo = 42
@dataclass
class EnumExample:
foo: BasicEnum = "toto"
@@ -66,6 +80,14 @@ class EnumExample:
self.foo = BasicEnum(self.foo)
@dataclass
class MixedTypeEnumExample:
foo: MixedTypeEnum = "toto"
def __post_init__(self):
self.foo = MixedTypeEnum(self.foo)
@dataclass
class OptionalExample:
foo: Optional[int] = None
@@ -111,6 +133,14 @@ class HfArgumentParserTest(unittest.TestCase):
for x, y in zip(a._actions, b._actions):
xx = {k: v for k, v in vars(x).items() if k != "container"}
yy = {k: v for k, v in vars(y).items() if k != "container"}
# Choices with mixed type have custom function as "type"
# So we need to compare results directly for equality
if xx.get("choices", None) and yy.get("choices", None):
for expected_choice in yy["choices"] + xx["choices"]:
self.assertEqual(xx["type"](expected_choice), yy["type"](expected_choice))
del xx["type"], yy["type"]
self.assertEqual(xx, yy)
def test_basic(self):
@@ -163,21 +193,56 @@ class HfArgumentParserTest(unittest.TestCase):
self.assertEqual(args, Namespace(foo=False, baz=False, opt=False))
def test_with_enum(self):
parser = HfArgumentParser(EnumExample)
parser = HfArgumentParser(MixedTypeEnumExample)
expected = argparse.ArgumentParser()
expected.add_argument("--foo", default="toto", choices=["titi", "toto"], type=str)
expected.add_argument(
"--foo",
default="toto",
choices=["titi", "toto", 42],
type=make_choice_type_function(["titi", "toto", 42]),
)
self.argparsersEqual(parser, expected)
args = parser.parse_args([])
self.assertEqual(args.foo, "toto")
enum_ex = parser.parse_args_into_dataclasses([])[0]
self.assertEqual(enum_ex.foo, BasicEnum.toto)
self.assertEqual(enum_ex.foo, MixedTypeEnum.toto)
args = parser.parse_args(["--foo", "titi"])
self.assertEqual(args.foo, "titi")
enum_ex = parser.parse_args_into_dataclasses(["--foo", "titi"])[0]
self.assertEqual(enum_ex.foo, BasicEnum.titi)
self.assertEqual(enum_ex.foo, MixedTypeEnum.titi)
args = parser.parse_args(["--foo", "42"])
self.assertEqual(args.foo, 42)
enum_ex = parser.parse_args_into_dataclasses(["--foo", "42"])[0]
self.assertEqual(enum_ex.foo, MixedTypeEnum.fourtytwo)
def test_with_literal(self):
@dataclass
class LiteralExample:
foo: Literal["titi", "toto", 42] = "toto"
parser = HfArgumentParser(LiteralExample)
expected = argparse.ArgumentParser()
expected.add_argument(
"--foo",
default="toto",
choices=("titi", "toto", 42),
type=make_choice_type_function(["titi", "toto", 42]),
)
self.argparsersEqual(parser, expected)
args = parser.parse_args([])
self.assertEqual(args.foo, "toto")
args = parser.parse_args(["--foo", "titi"])
self.assertEqual(args.foo, "titi")
args = parser.parse_args(["--foo", "42"])
self.assertEqual(args.foo, 42)
def test_with_list(self):
parser = HfArgumentParser(ListExample)
@@ -222,7 +287,12 @@ class HfArgumentParserTest(unittest.TestCase):
expected = argparse.ArgumentParser()
expected.add_argument("--required_list", nargs="+", type=int, required=True)
expected.add_argument("--required_str", type=str, required=True)
expected.add_argument("--required_enum", type=str, choices=["titi", "toto"], required=True)
expected.add_argument(
"--required_enum",
type=make_choice_type_function(["titi", "toto"]),
choices=["titi", "toto"],
required=True,
)
self.argparsersEqual(parser, expected)
def test_with_string_literal_annotation(self):
@@ -230,7 +300,12 @@ class HfArgumentParserTest(unittest.TestCase):
expected = argparse.ArgumentParser()
expected.add_argument("--foo", type=int, required=True)
expected.add_argument("--required_enum", type=str, choices=["titi", "toto"], required=True)
expected.add_argument(
"--required_enum",
type=make_choice_type_function(["titi", "toto"]),
choices=["titi", "toto"],
required=True,
)
expected.add_argument("--opt", type=string_to_bool, default=None)
expected.add_argument("--baz", default="toto", type=str, help="help message")
expected.add_argument("--foo_str", nargs="+", default=["Hallo", "Bonjour", "Hello"], type=str)