HfArgumentParser: allow for hyhenated field names in long-options (#33990)
Allow for hyphenated field names in long-options argparse converts hyphens into underscores before assignment (e.g., an option passed as `--long-option` will be stored under `long_option`), So there is no need to pass options as literal attributes, as in `--long_option` (with an underscore instead of a hyphen). This commit ensures that this behavior is respected by `parse_args_into_dataclasses` as well. Issue: #33933 Co-authored-by: Daniel Marti <mrtidm@amazon.com>
This commit is contained in:
@@ -138,7 +138,14 @@ class HfArgumentParser(ArgumentParser):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _parse_dataclass_field(parser: ArgumentParser, field: dataclasses.Field):
|
def _parse_dataclass_field(parser: ArgumentParser, field: dataclasses.Field):
|
||||||
field_name = f"--{field.name}"
|
# Long-option strings are conventionlly separated by hyphens rather
|
||||||
|
# than underscores, e.g., "--long-format" rather than "--long_format".
|
||||||
|
# Argparse converts hyphens to underscores so that the destination
|
||||||
|
# string is a valid attribute name. Hf_argparser should do the same.
|
||||||
|
long_options = [f"--{field.name}"]
|
||||||
|
if "_" in field.name:
|
||||||
|
long_options.append(f"--{field.name.replace('_', '-')}")
|
||||||
|
|
||||||
kwargs = field.metadata.copy()
|
kwargs = field.metadata.copy()
|
||||||
# field.metadata is not used at all by Data Classes,
|
# field.metadata is not used at all by Data Classes,
|
||||||
# it is provided as a third-party extension mechanism.
|
# it is provided as a third-party extension mechanism.
|
||||||
@@ -198,11 +205,11 @@ class HfArgumentParser(ArgumentParser):
|
|||||||
if field.type is bool or (field.default is not None and field.default is not dataclasses.MISSING):
|
if field.type is bool or (field.default is not None and field.default is not dataclasses.MISSING):
|
||||||
# Default value is False if we have no default when of type bool.
|
# Default value is False if we have no default when of type bool.
|
||||||
default = False if field.default is dataclasses.MISSING else field.default
|
default = False if field.default is dataclasses.MISSING else field.default
|
||||||
# This is the value that will get picked if we don't include --field_name in any way
|
# This is the value that will get picked if we don't include --{field.name} in any way
|
||||||
kwargs["default"] = default
|
kwargs["default"] = default
|
||||||
# This tells argparse we accept 0 or 1 value after --field_name
|
# This tells argparse we accept 0 or 1 value after --{field.name}
|
||||||
kwargs["nargs"] = "?"
|
kwargs["nargs"] = "?"
|
||||||
# This is the value that will get picked if we do --field_name (without value)
|
# This is the value that will get picked if we do --{field.name} (without value)
|
||||||
kwargs["const"] = True
|
kwargs["const"] = True
|
||||||
elif isclass(origin_type) and issubclass(origin_type, list):
|
elif isclass(origin_type) and issubclass(origin_type, list):
|
||||||
kwargs["type"] = field.type.__args__[0]
|
kwargs["type"] = field.type.__args__[0]
|
||||||
@@ -219,7 +226,7 @@ class HfArgumentParser(ArgumentParser):
|
|||||||
kwargs["default"] = field.default_factory()
|
kwargs["default"] = field.default_factory()
|
||||||
else:
|
else:
|
||||||
kwargs["required"] = True
|
kwargs["required"] = True
|
||||||
parser.add_argument(field_name, *aliases, **kwargs)
|
parser.add_argument(*long_options, *aliases, **kwargs)
|
||||||
|
|
||||||
# Add a complement `no_*` argument for a boolean field AFTER the initial field has already been added.
|
# Add a complement `no_*` argument for a boolean field AFTER the initial field has already been added.
|
||||||
# Order is important for arguments with the same destination!
|
# Order is important for arguments with the same destination!
|
||||||
@@ -227,7 +234,13 @@ class HfArgumentParser(ArgumentParser):
|
|||||||
# here and we do not need those changes/additional keys.
|
# here and we do not need those changes/additional keys.
|
||||||
if field.default is True and (field.type is bool or field.type == Optional[bool]):
|
if field.default is True and (field.type is bool or field.type == Optional[bool]):
|
||||||
bool_kwargs["default"] = False
|
bool_kwargs["default"] = False
|
||||||
parser.add_argument(f"--no_{field.name}", action="store_false", dest=field.name, **bool_kwargs)
|
parser.add_argument(
|
||||||
|
f"--no_{field.name}",
|
||||||
|
f"--no-{field.name.replace('_', '-')}",
|
||||||
|
action="store_false",
|
||||||
|
dest=field.name,
|
||||||
|
**bool_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
def _add_dataclass_arguments(self, dtype: DataClassType):
|
def _add_dataclass_arguments(self, dtype: DataClassType):
|
||||||
if hasattr(dtype, "_argument_group_name"):
|
if hasattr(dtype, "_argument_group_name"):
|
||||||
|
|||||||
@@ -189,7 +189,7 @@ class HfArgumentParserTest(unittest.TestCase):
|
|||||||
expected.add_argument("--baz", type=string_to_bool, default=True, const=True, nargs="?")
|
expected.add_argument("--baz", type=string_to_bool, default=True, const=True, nargs="?")
|
||||||
# A boolean no_* argument always has to come after its "default: True" regular counter-part
|
# A boolean no_* argument always has to come after its "default: True" regular counter-part
|
||||||
# and its default must be set to False
|
# and its default must be set to False
|
||||||
expected.add_argument("--no_baz", action="store_false", default=False, dest="baz")
|
expected.add_argument("--no_baz", "--no-baz", action="store_false", default=False, dest="baz")
|
||||||
expected.add_argument("--opt", type=string_to_bool, default=None)
|
expected.add_argument("--opt", type=string_to_bool, default=None)
|
||||||
|
|
||||||
dataclass_types = [WithDefaultBoolExample]
|
dataclass_types = [WithDefaultBoolExample]
|
||||||
@@ -206,6 +206,9 @@ class HfArgumentParserTest(unittest.TestCase):
|
|||||||
args = parser.parse_args(["--foo", "--no_baz"])
|
args = parser.parse_args(["--foo", "--no_baz"])
|
||||||
self.assertEqual(args, Namespace(foo=True, baz=False, opt=None))
|
self.assertEqual(args, Namespace(foo=True, baz=False, opt=None))
|
||||||
|
|
||||||
|
args = parser.parse_args(["--foo", "--no-baz"])
|
||||||
|
self.assertEqual(args, Namespace(foo=True, baz=False, opt=None))
|
||||||
|
|
||||||
args = parser.parse_args(["--foo", "--baz"])
|
args = parser.parse_args(["--foo", "--baz"])
|
||||||
self.assertEqual(args, Namespace(foo=True, baz=True, opt=None))
|
self.assertEqual(args, Namespace(foo=True, baz=True, opt=None))
|
||||||
|
|
||||||
@@ -271,10 +274,10 @@ class HfArgumentParserTest(unittest.TestCase):
|
|||||||
parser = HfArgumentParser(ListExample)
|
parser = HfArgumentParser(ListExample)
|
||||||
|
|
||||||
expected = argparse.ArgumentParser()
|
expected = argparse.ArgumentParser()
|
||||||
expected.add_argument("--foo_int", nargs="+", default=[], type=int)
|
expected.add_argument("--foo_int", "--foo-int", nargs="+", default=[], type=int)
|
||||||
expected.add_argument("--bar_int", nargs="+", default=[1, 2, 3], type=int)
|
expected.add_argument("--bar_int", "--bar-int", nargs="+", default=[1, 2, 3], type=int)
|
||||||
expected.add_argument("--foo_str", nargs="+", default=["Hallo", "Bonjour", "Hello"], type=str)
|
expected.add_argument("--foo_str", "--foo-str", nargs="+", default=["Hallo", "Bonjour", "Hello"], type=str)
|
||||||
expected.add_argument("--foo_float", nargs="+", default=[0.1, 0.2, 0.3], type=float)
|
expected.add_argument("--foo_float", "--foo-float", nargs="+", default=[0.1, 0.2, 0.3], type=float)
|
||||||
|
|
||||||
self.argparsersEqual(parser, expected)
|
self.argparsersEqual(parser, expected)
|
||||||
|
|
||||||
@@ -287,6 +290,9 @@ class HfArgumentParserTest(unittest.TestCase):
|
|||||||
args = parser.parse_args("--foo_int 1 --bar_int 2 3 --foo_str a b c --foo_float 0.1 0.7".split())
|
args = parser.parse_args("--foo_int 1 --bar_int 2 3 --foo_str a b c --foo_float 0.1 0.7".split())
|
||||||
self.assertEqual(args, Namespace(foo_int=[1], bar_int=[2, 3], foo_str=["a", "b", "c"], foo_float=[0.1, 0.7]))
|
self.assertEqual(args, Namespace(foo_int=[1], bar_int=[2, 3], foo_str=["a", "b", "c"], foo_float=[0.1, 0.7]))
|
||||||
|
|
||||||
|
args = parser.parse_args("--foo-int 1 --bar-int 2 3 --foo-str a b c --foo-float 0.1 0.7".split())
|
||||||
|
self.assertEqual(args, Namespace(foo_int=[1], bar_int=[2, 3], foo_str=["a", "b", "c"], foo_float=[0.1, 0.7]))
|
||||||
|
|
||||||
def test_with_optional(self):
|
def test_with_optional(self):
|
||||||
expected = argparse.ArgumentParser()
|
expected = argparse.ArgumentParser()
|
||||||
expected.add_argument("--foo", default=None, type=int)
|
expected.add_argument("--foo", default=None, type=int)
|
||||||
@@ -314,10 +320,11 @@ class HfArgumentParserTest(unittest.TestCase):
|
|||||||
parser = HfArgumentParser(RequiredExample)
|
parser = HfArgumentParser(RequiredExample)
|
||||||
|
|
||||||
expected = argparse.ArgumentParser()
|
expected = argparse.ArgumentParser()
|
||||||
expected.add_argument("--required_list", nargs="+", type=int, required=True)
|
expected.add_argument("--required_list", "--required-list", nargs="+", type=int, required=True)
|
||||||
expected.add_argument("--required_str", type=str, required=True)
|
expected.add_argument("--required_str", "--required-str", type=str, required=True)
|
||||||
expected.add_argument(
|
expected.add_argument(
|
||||||
"--required_enum",
|
"--required_enum",
|
||||||
|
"--required-enum",
|
||||||
type=make_choice_type_function(["titi", "toto"]),
|
type=make_choice_type_function(["titi", "toto"]),
|
||||||
choices=["titi", "toto"],
|
choices=["titi", "toto"],
|
||||||
required=True,
|
required=True,
|
||||||
@@ -331,13 +338,14 @@ class HfArgumentParserTest(unittest.TestCase):
|
|||||||
expected.add_argument("--foo", type=int, required=True)
|
expected.add_argument("--foo", type=int, required=True)
|
||||||
expected.add_argument(
|
expected.add_argument(
|
||||||
"--required_enum",
|
"--required_enum",
|
||||||
|
"--required-enum",
|
||||||
type=make_choice_type_function(["titi", "toto"]),
|
type=make_choice_type_function(["titi", "toto"]),
|
||||||
choices=["titi", "toto"],
|
choices=["titi", "toto"],
|
||||||
required=True,
|
required=True,
|
||||||
)
|
)
|
||||||
expected.add_argument("--opt", type=string_to_bool, default=None)
|
expected.add_argument("--opt", type=string_to_bool, default=None)
|
||||||
expected.add_argument("--baz", default="toto", type=str, help="help message")
|
expected.add_argument("--baz", default="toto", type=str, help="help message")
|
||||||
expected.add_argument("--foo_str", nargs="+", default=["Hallo", "Bonjour", "Hello"], type=str)
|
expected.add_argument("--foo_str", "--foo-str", nargs="+", default=["Hallo", "Bonjour", "Hello"], type=str)
|
||||||
self.argparsersEqual(parser, expected)
|
self.argparsersEqual(parser, expected)
|
||||||
|
|
||||||
def test_parse_dict(self):
|
def test_parse_dict(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user