Enhance HfArgumentParser functionality and ease of use (#20323)
* Enhance HfArgumentParser
* Fix type hints for older Python versions
* Fix and add tests (+formatting)
* Add changes
* doc-builder formatting
* Remove unused import "Call"
This commit is contained in:
committed by GitHub
parent 96783e53b4
commit 1e3f17b5ab
@@ -25,7 +25,15 @@ from typing import List, Optional
|
||||
|
||||
import yaml
|
||||
from transformers import HfArgumentParser, TrainingArguments
|
||||
from transformers.hf_argparser import string_to_bool
|
||||
from transformers.hf_argparser import make_choice_type_function, string_to_bool
|
||||
|
||||
|
||||
try:
|
||||
# For Python versions <3.8, Literal is not in typing: https://peps.python.org/pep-0586/
|
||||
from typing import Literal
|
||||
except ImportError:
|
||||
# For Python 3.7
|
||||
from typing_extensions import Literal
|
||||
|
||||
|
||||
def list_field(default=None, metadata=None):
|
||||
@@ -58,6 +66,12 @@ class BasicEnum(Enum):
|
||||
toto = "toto"
|
||||
|
||||
|
||||
class MixedTypeEnum(Enum):
|
||||
titi = "titi"
|
||||
toto = "toto"
|
||||
fourtytwo = 42
|
||||
|
||||
|
||||
@dataclass
|
||||
class EnumExample:
|
||||
foo: BasicEnum = "toto"
|
||||
@@ -66,6 +80,14 @@ class EnumExample:
|
||||
self.foo = BasicEnum(self.foo)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MixedTypeEnumExample:
|
||||
foo: MixedTypeEnum = "toto"
|
||||
|
||||
def __post_init__(self):
|
||||
self.foo = MixedTypeEnum(self.foo)
|
||||
|
||||
|
||||
@dataclass
|
||||
class OptionalExample:
|
||||
foo: Optional[int] = None
|
||||
@@ -111,6 +133,14 @@ class HfArgumentParserTest(unittest.TestCase):
|
||||
for x, y in zip(a._actions, b._actions):
|
||||
xx = {k: v for k, v in vars(x).items() if k != "container"}
|
||||
yy = {k: v for k, v in vars(y).items() if k != "container"}
|
||||
|
||||
# Choices with mixed type have custom function as "type"
|
||||
# So we need to compare results directly for equality
|
||||
if xx.get("choices", None) and yy.get("choices", None):
|
||||
for expected_choice in yy["choices"] + xx["choices"]:
|
||||
self.assertEqual(xx["type"](expected_choice), yy["type"](expected_choice))
|
||||
del xx["type"], yy["type"]
|
||||
|
||||
self.assertEqual(xx, yy)
|
||||
|
||||
def test_basic(self):
|
||||
@@ -163,21 +193,56 @@ class HfArgumentParserTest(unittest.TestCase):
|
||||
self.assertEqual(args, Namespace(foo=False, baz=False, opt=False))
|
||||
|
||||
def test_with_enum(self):
|
||||
parser = HfArgumentParser(EnumExample)
|
||||
parser = HfArgumentParser(MixedTypeEnumExample)
|
||||
|
||||
expected = argparse.ArgumentParser()
|
||||
expected.add_argument("--foo", default="toto", choices=["titi", "toto"], type=str)
|
||||
expected.add_argument(
|
||||
"--foo",
|
||||
default="toto",
|
||||
choices=["titi", "toto", 42],
|
||||
type=make_choice_type_function(["titi", "toto", 42]),
|
||||
)
|
||||
self.argparsersEqual(parser, expected)
|
||||
|
||||
args = parser.parse_args([])
|
||||
self.assertEqual(args.foo, "toto")
|
||||
enum_ex = parser.parse_args_into_dataclasses([])[0]
|
||||
self.assertEqual(enum_ex.foo, BasicEnum.toto)
|
||||
self.assertEqual(enum_ex.foo, MixedTypeEnum.toto)
|
||||
|
||||
args = parser.parse_args(["--foo", "titi"])
|
||||
self.assertEqual(args.foo, "titi")
|
||||
enum_ex = parser.parse_args_into_dataclasses(["--foo", "titi"])[0]
|
||||
self.assertEqual(enum_ex.foo, BasicEnum.titi)
|
||||
self.assertEqual(enum_ex.foo, MixedTypeEnum.titi)
|
||||
|
||||
args = parser.parse_args(["--foo", "42"])
|
||||
self.assertEqual(args.foo, 42)
|
||||
enum_ex = parser.parse_args_into_dataclasses(["--foo", "42"])[0]
|
||||
self.assertEqual(enum_ex.foo, MixedTypeEnum.fourtytwo)
|
||||
|
||||
def test_with_literal(self):
|
||||
@dataclass
|
||||
class LiteralExample:
|
||||
foo: Literal["titi", "toto", 42] = "toto"
|
||||
|
||||
parser = HfArgumentParser(LiteralExample)
|
||||
|
||||
expected = argparse.ArgumentParser()
|
||||
expected.add_argument(
|
||||
"--foo",
|
||||
default="toto",
|
||||
choices=("titi", "toto", 42),
|
||||
type=make_choice_type_function(["titi", "toto", 42]),
|
||||
)
|
||||
self.argparsersEqual(parser, expected)
|
||||
|
||||
args = parser.parse_args([])
|
||||
self.assertEqual(args.foo, "toto")
|
||||
|
||||
args = parser.parse_args(["--foo", "titi"])
|
||||
self.assertEqual(args.foo, "titi")
|
||||
|
||||
args = parser.parse_args(["--foo", "42"])
|
||||
self.assertEqual(args.foo, 42)
|
||||
|
||||
def test_with_list(self):
|
||||
parser = HfArgumentParser(ListExample)
|
||||
@@ -222,7 +287,12 @@ class HfArgumentParserTest(unittest.TestCase):
|
||||
expected = argparse.ArgumentParser()
|
||||
expected.add_argument("--required_list", nargs="+", type=int, required=True)
|
||||
expected.add_argument("--required_str", type=str, required=True)
|
||||
expected.add_argument("--required_enum", type=str, choices=["titi", "toto"], required=True)
|
||||
expected.add_argument(
|
||||
"--required_enum",
|
||||
type=make_choice_type_function(["titi", "toto"]),
|
||||
choices=["titi", "toto"],
|
||||
required=True,
|
||||
)
|
||||
self.argparsersEqual(parser, expected)
|
||||
|
||||
def test_with_string_literal_annotation(self):
|
||||
@@ -230,7 +300,12 @@ class HfArgumentParserTest(unittest.TestCase):
|
||||
|
||||
expected = argparse.ArgumentParser()
|
||||
expected.add_argument("--foo", type=int, required=True)
|
||||
expected.add_argument("--required_enum", type=str, choices=["titi", "toto"], required=True)
|
||||
expected.add_argument(
|
||||
"--required_enum",
|
||||
type=make_choice_type_function(["titi", "toto"]),
|
||||
choices=["titi", "toto"],
|
||||
required=True,
|
||||
)
|
||||
expected.add_argument("--opt", type=string_to_bool, default=None)
|
||||
expected.add_argument("--baz", default="toto", type=str, help="help message")
|
||||
expected.add_argument("--foo_str", nargs="+", default=["Hallo", "Bonjour", "Hello"], type=str)
|
||||
|
||||
Reference in New Issue
Block a user