Enhance HfArgumentParser functionality and ease of use (#20323)
* Enhance HfArgumentParser * Fix type hints for older python versions * Fix and add tests (+formatting) * Add changes * doc-builder formatting * Remove unused import "Call"
This commit is contained in:
committed by
GitHub
parent
96783e53b4
commit
1e3f17b5ab
@@ -20,11 +20,19 @@ from copy import copy
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from inspect import isclass
|
from inspect import isclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, Iterable, NewType, Optional, Tuple, Union, get_type_hints
|
from typing import Any, Callable, Dict, Iterable, List, NewType, Optional, Tuple, Union, get_type_hints
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
|
|
||||||
|
try:
|
||||||
|
# For Python versions <3.8, Literal is not in typing: https://peps.python.org/pep-0586/
|
||||||
|
from typing import Literal
|
||||||
|
except ImportError:
|
||||||
|
# For Python 3.7
|
||||||
|
from typing_extensions import Literal
|
||||||
|
|
||||||
|
|
||||||
DataClass = NewType("DataClass", Any)
|
DataClass = NewType("DataClass", Any)
|
||||||
DataClassType = NewType("DataClassType", Any)
|
DataClassType = NewType("DataClassType", Any)
|
||||||
|
|
||||||
@@ -43,6 +51,68 @@ def string_to_bool(v):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def make_choice_type_function(choices: Iterable[Any]) -> Callable[[str], Any]:
    """
    Creates a mapping function from each choice's string representation to the actual value. Used to support multiple
    value types for a single argument.

    Args:
        choices (Iterable[Any]):
            The allowed values. Any iterable works (e.g. a list of enum values or the tuple of arguments of a
            `Literal`). If two choices share the same `str()` representation, the one that comes last wins.

    Returns:
        Callable[[str], Any]: Mapping function from string representation to actual value for each choice. Strings
        that do not correspond to any choice are returned unchanged.
    """
    str_to_choice = {str(choice): choice for choice in choices}

    def convert(arg: str) -> Any:
        # Unknown strings are passed through untouched instead of raising here, leaving any rejection of invalid
        # values to whatever validates the converted result afterwards.
        return str_to_choice.get(arg, arg)

    return convert
|
||||||
|
|
||||||
|
|
||||||
|
def HfArg(
    *,
    aliases: Optional[Union[str, List[str]]] = None,
    help: Optional[str] = None,
    default: Any = dataclasses.MISSING,
    default_factory: Callable[[], Any] = dataclasses.MISSING,
    metadata: Optional[dict] = None,
    **kwargs,
) -> dataclasses.Field:
    """Argument helper enabling a concise syntax to create dataclass fields for parsing with `HfArgumentParser`.

    Example comparing the use of `HfArg` and `dataclasses.field`:
    ```
    @dataclass
    class Args:
        regular_arg: str = dataclasses.field(default="Huggingface", metadata={"aliases": ["--example", "-e"], "help": "This syntax could be better!"})
        hf_arg: str = HfArg(default="Huggingface", aliases=["--example", "-e"], help="What a nice syntax!")
    ```

    Args:
        aliases (Union[str, List[str]], optional):
            Single string or list of strings of aliases to pass on to argparse, e.g. `aliases=["--example", "-e"]`.
            Defaults to None.
        help (str, optional): Help string to pass on to argparse that can be displayed with --help. Defaults to None.
        default (Any, optional):
            Default value for the argument. If neither `default` nor `default_factory` is specified, the argument is
            required. Defaults to dataclasses.MISSING.
        default_factory (Callable[[], Any], optional):
            The default_factory is a 0-argument function called to initialize a field's value. It is useful to provide
            default values for mutable types, e.g. lists: `default_factory=list`. Mutually exclusive with `default=`.
            Defaults to dataclasses.MISSING.
        metadata (dict, optional):
            Further metadata to pass on to `dataclasses.field`. The dict is copied, so the caller's object is never
            modified and can safely be reused for several fields. Defaults to None.
        **kwargs: Any further keyword arguments are forwarded unchanged to `dataclasses.field`.

    Returns:
        Field: A `dataclasses.Field` with the desired properties.
    """
    # Always build a fresh dict:
    # - a `{}` default in the signature would be shared across calls (mutable default argument);
    # - writing into a caller-supplied dict would leak "aliases"/"help" into every other field built from it.
    metadata = {} if metadata is None else dict(metadata)
    if aliases is not None:
        metadata["aliases"] = aliases
    if help is not None:
        metadata["help"] = help

    return dataclasses.field(metadata=metadata, default=default, default_factory=default_factory, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class HfArgumentParser(ArgumentParser):
|
class HfArgumentParser(ArgumentParser):
|
||||||
"""
|
"""
|
||||||
This subclass of `argparse.ArgumentParser` uses type hints on dataclasses to generate arguments.
|
This subclass of `argparse.ArgumentParser` uses type hints on dataclasses to generate arguments.
|
||||||
@@ -84,6 +154,10 @@ class HfArgumentParser(ArgumentParser):
|
|||||||
"`typing.get_type_hints` method by default"
|
"`typing.get_type_hints` method by default"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
aliases = kwargs.pop("aliases", [])
|
||||||
|
if isinstance(aliases, str):
|
||||||
|
aliases = [aliases]
|
||||||
|
|
||||||
origin_type = getattr(field.type, "__origin__", field.type)
|
origin_type = getattr(field.type, "__origin__", field.type)
|
||||||
if origin_type is Union:
|
if origin_type is Union:
|
||||||
if str not in field.type.__args__ and (
|
if str not in field.type.__args__ and (
|
||||||
@@ -108,9 +182,14 @@ class HfArgumentParser(ArgumentParser):
|
|||||||
# A variable to store kwargs for a boolean field, if needed
|
# A variable to store kwargs for a boolean field, if needed
|
||||||
# so that we can init a `no_*` complement argument (see below)
|
# so that we can init a `no_*` complement argument (see below)
|
||||||
bool_kwargs = {}
|
bool_kwargs = {}
|
||||||
if isinstance(field.type, type) and issubclass(field.type, Enum):
|
if origin_type is Literal or (isinstance(field.type, type) and issubclass(field.type, Enum)):
|
||||||
kwargs["choices"] = [x.value for x in field.type]
|
if origin_type is Literal:
|
||||||
kwargs["type"] = type(kwargs["choices"][0])
|
kwargs["choices"] = field.type.__args__
|
||||||
|
else:
|
||||||
|
kwargs["choices"] = [x.value for x in field.type]
|
||||||
|
|
||||||
|
kwargs["type"] = make_choice_type_function(kwargs["choices"])
|
||||||
|
|
||||||
if field.default is not dataclasses.MISSING:
|
if field.default is not dataclasses.MISSING:
|
||||||
kwargs["default"] = field.default
|
kwargs["default"] = field.default
|
||||||
else:
|
else:
|
||||||
@@ -146,7 +225,7 @@ class HfArgumentParser(ArgumentParser):
|
|||||||
kwargs["default"] = field.default_factory()
|
kwargs["default"] = field.default_factory()
|
||||||
else:
|
else:
|
||||||
kwargs["required"] = True
|
kwargs["required"] = True
|
||||||
parser.add_argument(field_name, **kwargs)
|
parser.add_argument(field_name, *aliases, **kwargs)
|
||||||
|
|
||||||
# Add a complement `no_*` argument for a boolean field AFTER the initial field has already been added.
|
# Add a complement `no_*` argument for a boolean field AFTER the initial field has already been added.
|
||||||
# Order is important for arguments with the same destination!
|
# Order is important for arguments with the same destination!
|
||||||
@@ -178,7 +257,12 @@ class HfArgumentParser(ArgumentParser):
|
|||||||
self._parse_dataclass_field(parser, field)
|
self._parse_dataclass_field(parser, field)
|
||||||
|
|
||||||
def parse_args_into_dataclasses(
|
def parse_args_into_dataclasses(
|
||||||
self, args=None, return_remaining_strings=False, look_for_args_file=True, args_filename=None
|
self,
|
||||||
|
args=None,
|
||||||
|
return_remaining_strings=False,
|
||||||
|
look_for_args_file=True,
|
||||||
|
args_filename=None,
|
||||||
|
args_file_flag=None,
|
||||||
) -> Tuple[DataClass, ...]:
|
) -> Tuple[DataClass, ...]:
|
||||||
"""
|
"""
|
||||||
Parse command-line args into instances of the specified dataclass types.
|
Parse command-line args into instances of the specified dataclass types.
|
||||||
@@ -196,6 +280,9 @@ class HfArgumentParser(ArgumentParser):
|
|||||||
process, and will append its potential content to the command line args.
|
process, and will append its potential content to the command line args.
|
||||||
args_filename:
|
args_filename:
|
||||||
If not None, will uses this file instead of the ".args" file specified in the previous argument.
|
If not None, will uses this file instead of the ".args" file specified in the previous argument.
|
||||||
|
args_file_flag:
|
||||||
|
If not None, will look for a file in the command-line args specified with this flag. The flag can be
|
||||||
|
specified multiple times and precedence is determined by the order (last one wins).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple consisting of:
|
Tuple consisting of:
|
||||||
@@ -205,17 +292,36 @@ class HfArgumentParser(ArgumentParser):
|
|||||||
after initialization.
|
after initialization.
|
||||||
- The potential list of remaining argument strings. (same as argparse.ArgumentParser.parse_known_args)
|
- The potential list of remaining argument strings. (same as argparse.ArgumentParser.parse_known_args)
|
||||||
"""
|
"""
|
||||||
if args_filename or (look_for_args_file and len(sys.argv)):
|
|
||||||
if args_filename:
|
|
||||||
args_file = Path(args_filename)
|
|
||||||
else:
|
|
||||||
args_file = Path(sys.argv[0]).with_suffix(".args")
|
|
||||||
|
|
||||||
if args_file.exists():
|
if args_file_flag or args_filename or (look_for_args_file and len(sys.argv)):
|
||||||
fargs = args_file.read_text().split()
|
args_files = []
|
||||||
args = fargs + args if args is not None else fargs + sys.argv[1:]
|
|
||||||
# in case of duplicate arguments the first one has precedence
|
if args_filename:
|
||||||
# so we append rather than prepend.
|
args_files.append(Path(args_filename))
|
||||||
|
elif look_for_args_file and len(sys.argv):
|
||||||
|
args_files.append(Path(sys.argv[0]).with_suffix(".args"))
|
||||||
|
|
||||||
|
# args files specified via command line flag should overwrite default args files so we add them last
|
||||||
|
if args_file_flag:
|
||||||
|
# Create special parser just to extract the args_file_flag values
|
||||||
|
args_file_parser = ArgumentParser()
|
||||||
|
args_file_parser.add_argument(args_file_flag, type=str, action="append")
|
||||||
|
|
||||||
|
# Use only remaining args for further parsing (remove the args_file_flag)
|
||||||
|
cfg, args = args_file_parser.parse_known_args(args=args)
|
||||||
|
cmd_args_file_paths = vars(cfg).get(args_file_flag.lstrip("-"), None)
|
||||||
|
|
||||||
|
if cmd_args_file_paths:
|
||||||
|
args_files.extend([Path(p) for p in cmd_args_file_paths])
|
||||||
|
|
||||||
|
file_args = []
|
||||||
|
for args_file in args_files:
|
||||||
|
if args_file.exists():
|
||||||
|
file_args += args_file.read_text().split()
|
||||||
|
|
||||||
|
# in case of duplicate arguments the last one has precedence
|
||||||
|
# args specified via the command line should overwrite args from files, so we add them last
|
||||||
|
args = file_args + args if args is not None else file_args + sys.argv[1:]
|
||||||
namespace, remaining_args = self.parse_known_args(args=args)
|
namespace, remaining_args = self.parse_known_args(args=args)
|
||||||
outputs = []
|
outputs = []
|
||||||
for dtype in self.dataclass_types:
|
for dtype in self.dataclass_types:
|
||||||
|
|||||||
@@ -25,7 +25,15 @@ from typing import List, Optional
|
|||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
from transformers import HfArgumentParser, TrainingArguments
|
from transformers import HfArgumentParser, TrainingArguments
|
||||||
from transformers.hf_argparser import string_to_bool
|
from transformers.hf_argparser import make_choice_type_function, string_to_bool
|
||||||
|
|
||||||
|
|
||||||
|
try:
|
||||||
|
# For Python versions <3.8, Literal is not in typing: https://peps.python.org/pep-0586/
|
||||||
|
from typing import Literal
|
||||||
|
except ImportError:
|
||||||
|
# For Python 3.7
|
||||||
|
from typing_extensions import Literal
|
||||||
|
|
||||||
|
|
||||||
def list_field(default=None, metadata=None):
|
def list_field(default=None, metadata=None):
|
||||||
@@ -58,6 +66,12 @@ class BasicEnum(Enum):
|
|||||||
toto = "toto"
|
toto = "toto"
|
||||||
|
|
||||||
|
|
||||||
|
class MixedTypeEnum(Enum):
|
||||||
|
titi = "titi"
|
||||||
|
toto = "toto"
|
||||||
|
fourtytwo = 42
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class EnumExample:
|
class EnumExample:
|
||||||
foo: BasicEnum = "toto"
|
foo: BasicEnum = "toto"
|
||||||
@@ -66,6 +80,14 @@ class EnumExample:
|
|||||||
self.foo = BasicEnum(self.foo)
|
self.foo = BasicEnum(self.foo)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MixedTypeEnumExample:
|
||||||
|
foo: MixedTypeEnum = "toto"
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
self.foo = MixedTypeEnum(self.foo)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class OptionalExample:
|
class OptionalExample:
|
||||||
foo: Optional[int] = None
|
foo: Optional[int] = None
|
||||||
@@ -111,6 +133,14 @@ class HfArgumentParserTest(unittest.TestCase):
|
|||||||
for x, y in zip(a._actions, b._actions):
|
for x, y in zip(a._actions, b._actions):
|
||||||
xx = {k: v for k, v in vars(x).items() if k != "container"}
|
xx = {k: v for k, v in vars(x).items() if k != "container"}
|
||||||
yy = {k: v for k, v in vars(y).items() if k != "container"}
|
yy = {k: v for k, v in vars(y).items() if k != "container"}
|
||||||
|
|
||||||
|
# Choices with mixed type have custom function as "type"
|
||||||
|
# So we need to compare results directly for equality
|
||||||
|
if xx.get("choices", None) and yy.get("choices", None):
|
||||||
|
for expected_choice in yy["choices"] + xx["choices"]:
|
||||||
|
self.assertEqual(xx["type"](expected_choice), yy["type"](expected_choice))
|
||||||
|
del xx["type"], yy["type"]
|
||||||
|
|
||||||
self.assertEqual(xx, yy)
|
self.assertEqual(xx, yy)
|
||||||
|
|
||||||
def test_basic(self):
|
def test_basic(self):
|
||||||
@@ -163,21 +193,56 @@ class HfArgumentParserTest(unittest.TestCase):
|
|||||||
self.assertEqual(args, Namespace(foo=False, baz=False, opt=False))
|
self.assertEqual(args, Namespace(foo=False, baz=False, opt=False))
|
||||||
|
|
||||||
def test_with_enum(self):
|
def test_with_enum(self):
|
||||||
parser = HfArgumentParser(EnumExample)
|
parser = HfArgumentParser(MixedTypeEnumExample)
|
||||||
|
|
||||||
expected = argparse.ArgumentParser()
|
expected = argparse.ArgumentParser()
|
||||||
expected.add_argument("--foo", default="toto", choices=["titi", "toto"], type=str)
|
expected.add_argument(
|
||||||
|
"--foo",
|
||||||
|
default="toto",
|
||||||
|
choices=["titi", "toto", 42],
|
||||||
|
type=make_choice_type_function(["titi", "toto", 42]),
|
||||||
|
)
|
||||||
self.argparsersEqual(parser, expected)
|
self.argparsersEqual(parser, expected)
|
||||||
|
|
||||||
args = parser.parse_args([])
|
args = parser.parse_args([])
|
||||||
self.assertEqual(args.foo, "toto")
|
self.assertEqual(args.foo, "toto")
|
||||||
enum_ex = parser.parse_args_into_dataclasses([])[0]
|
enum_ex = parser.parse_args_into_dataclasses([])[0]
|
||||||
self.assertEqual(enum_ex.foo, BasicEnum.toto)
|
self.assertEqual(enum_ex.foo, MixedTypeEnum.toto)
|
||||||
|
|
||||||
args = parser.parse_args(["--foo", "titi"])
|
args = parser.parse_args(["--foo", "titi"])
|
||||||
self.assertEqual(args.foo, "titi")
|
self.assertEqual(args.foo, "titi")
|
||||||
enum_ex = parser.parse_args_into_dataclasses(["--foo", "titi"])[0]
|
enum_ex = parser.parse_args_into_dataclasses(["--foo", "titi"])[0]
|
||||||
self.assertEqual(enum_ex.foo, BasicEnum.titi)
|
self.assertEqual(enum_ex.foo, MixedTypeEnum.titi)
|
||||||
|
|
||||||
|
args = parser.parse_args(["--foo", "42"])
|
||||||
|
self.assertEqual(args.foo, 42)
|
||||||
|
enum_ex = parser.parse_args_into_dataclasses(["--foo", "42"])[0]
|
||||||
|
self.assertEqual(enum_ex.foo, MixedTypeEnum.fourtytwo)
|
||||||
|
|
||||||
|
def test_with_literal(self):
|
||||||
|
@dataclass
|
||||||
|
class LiteralExample:
|
||||||
|
foo: Literal["titi", "toto", 42] = "toto"
|
||||||
|
|
||||||
|
parser = HfArgumentParser(LiteralExample)
|
||||||
|
|
||||||
|
expected = argparse.ArgumentParser()
|
||||||
|
expected.add_argument(
|
||||||
|
"--foo",
|
||||||
|
default="toto",
|
||||||
|
choices=("titi", "toto", 42),
|
||||||
|
type=make_choice_type_function(["titi", "toto", 42]),
|
||||||
|
)
|
||||||
|
self.argparsersEqual(parser, expected)
|
||||||
|
|
||||||
|
args = parser.parse_args([])
|
||||||
|
self.assertEqual(args.foo, "toto")
|
||||||
|
|
||||||
|
args = parser.parse_args(["--foo", "titi"])
|
||||||
|
self.assertEqual(args.foo, "titi")
|
||||||
|
|
||||||
|
args = parser.parse_args(["--foo", "42"])
|
||||||
|
self.assertEqual(args.foo, 42)
|
||||||
|
|
||||||
def test_with_list(self):
|
def test_with_list(self):
|
||||||
parser = HfArgumentParser(ListExample)
|
parser = HfArgumentParser(ListExample)
|
||||||
@@ -222,7 +287,12 @@ class HfArgumentParserTest(unittest.TestCase):
|
|||||||
expected = argparse.ArgumentParser()
|
expected = argparse.ArgumentParser()
|
||||||
expected.add_argument("--required_list", nargs="+", type=int, required=True)
|
expected.add_argument("--required_list", nargs="+", type=int, required=True)
|
||||||
expected.add_argument("--required_str", type=str, required=True)
|
expected.add_argument("--required_str", type=str, required=True)
|
||||||
expected.add_argument("--required_enum", type=str, choices=["titi", "toto"], required=True)
|
expected.add_argument(
|
||||||
|
"--required_enum",
|
||||||
|
type=make_choice_type_function(["titi", "toto"]),
|
||||||
|
choices=["titi", "toto"],
|
||||||
|
required=True,
|
||||||
|
)
|
||||||
self.argparsersEqual(parser, expected)
|
self.argparsersEqual(parser, expected)
|
||||||
|
|
||||||
def test_with_string_literal_annotation(self):
|
def test_with_string_literal_annotation(self):
|
||||||
@@ -230,7 +300,12 @@ class HfArgumentParserTest(unittest.TestCase):
|
|||||||
|
|
||||||
expected = argparse.ArgumentParser()
|
expected = argparse.ArgumentParser()
|
||||||
expected.add_argument("--foo", type=int, required=True)
|
expected.add_argument("--foo", type=int, required=True)
|
||||||
expected.add_argument("--required_enum", type=str, choices=["titi", "toto"], required=True)
|
expected.add_argument(
|
||||||
|
"--required_enum",
|
||||||
|
type=make_choice_type_function(["titi", "toto"]),
|
||||||
|
choices=["titi", "toto"],
|
||||||
|
required=True,
|
||||||
|
)
|
||||||
expected.add_argument("--opt", type=string_to_bool, default=None)
|
expected.add_argument("--opt", type=string_to_bool, default=None)
|
||||||
expected.add_argument("--baz", default="toto", type=str, help="help message")
|
expected.add_argument("--baz", default="toto", type=str, help="help message")
|
||||||
expected.add_argument("--foo_str", nargs="+", default=["Hallo", "Bonjour", "Hello"], type=str)
|
expected.add_argument("--foo_str", nargs="+", default=["Hallo", "Bonjour", "Hello"], type=str)
|
||||||
|
|||||||
Reference in New Issue
Block a user