HfArgumentParser: allow for hyhenated field names in long-options (#33990)

Allow for hyphenated field names in long-options

argparse converts hyphens into underscores before assignment (e.g., an
option passed as `--long-option` will be stored under `long_option`), So
there is no need to pass options as literal attributes, as in
`--long_option` (with an underscore instead of a hyphen). This commit
ensures that this behavior is respected by `parse_args_into_dataclasses`
as well.

Issue: #33933

Co-authored-by: Daniel Marti <mrtidm@amazon.com>
This commit is contained in:
Dani Martí
2024-10-10 02:58:26 -07:00
committed by GitHub
parent adea67541a
commit a84c413773
2 changed files with 35 additions and 14 deletions

View File

@@ -189,7 +189,7 @@ class HfArgumentParserTest(unittest.TestCase):
expected.add_argument("--baz", type=string_to_bool, default=True, const=True, nargs="?")
# A boolean no_* argument always has to come after its "default: True" regular counter-part
# and its default must be set to False
expected.add_argument("--no_baz", action="store_false", default=False, dest="baz")
expected.add_argument("--no_baz", "--no-baz", action="store_false", default=False, dest="baz")
expected.add_argument("--opt", type=string_to_bool, default=None)
dataclass_types = [WithDefaultBoolExample]
@@ -206,6 +206,9 @@ class HfArgumentParserTest(unittest.TestCase):
args = parser.parse_args(["--foo", "--no_baz"])
self.assertEqual(args, Namespace(foo=True, baz=False, opt=None))
args = parser.parse_args(["--foo", "--no-baz"])
self.assertEqual(args, Namespace(foo=True, baz=False, opt=None))
args = parser.parse_args(["--foo", "--baz"])
self.assertEqual(args, Namespace(foo=True, baz=True, opt=None))
@@ -271,10 +274,10 @@ class HfArgumentParserTest(unittest.TestCase):
parser = HfArgumentParser(ListExample)
expected = argparse.ArgumentParser()
expected.add_argument("--foo_int", nargs="+", default=[], type=int)
expected.add_argument("--bar_int", nargs="+", default=[1, 2, 3], type=int)
expected.add_argument("--foo_str", nargs="+", default=["Hallo", "Bonjour", "Hello"], type=str)
expected.add_argument("--foo_float", nargs="+", default=[0.1, 0.2, 0.3], type=float)
expected.add_argument("--foo_int", "--foo-int", nargs="+", default=[], type=int)
expected.add_argument("--bar_int", "--bar-int", nargs="+", default=[1, 2, 3], type=int)
expected.add_argument("--foo_str", "--foo-str", nargs="+", default=["Hallo", "Bonjour", "Hello"], type=str)
expected.add_argument("--foo_float", "--foo-float", nargs="+", default=[0.1, 0.2, 0.3], type=float)
self.argparsersEqual(parser, expected)
@@ -287,6 +290,9 @@ class HfArgumentParserTest(unittest.TestCase):
args = parser.parse_args("--foo_int 1 --bar_int 2 3 --foo_str a b c --foo_float 0.1 0.7".split())
self.assertEqual(args, Namespace(foo_int=[1], bar_int=[2, 3], foo_str=["a", "b", "c"], foo_float=[0.1, 0.7]))
args = parser.parse_args("--foo-int 1 --bar-int 2 3 --foo-str a b c --foo-float 0.1 0.7".split())
self.assertEqual(args, Namespace(foo_int=[1], bar_int=[2, 3], foo_str=["a", "b", "c"], foo_float=[0.1, 0.7]))
def test_with_optional(self):
expected = argparse.ArgumentParser()
expected.add_argument("--foo", default=None, type=int)
@@ -314,10 +320,11 @@ class HfArgumentParserTest(unittest.TestCase):
parser = HfArgumentParser(RequiredExample)
expected = argparse.ArgumentParser()
expected.add_argument("--required_list", nargs="+", type=int, required=True)
expected.add_argument("--required_str", type=str, required=True)
expected.add_argument("--required_list", "--required-list", nargs="+", type=int, required=True)
expected.add_argument("--required_str", "--required-str", type=str, required=True)
expected.add_argument(
"--required_enum",
"--required-enum",
type=make_choice_type_function(["titi", "toto"]),
choices=["titi", "toto"],
required=True,
@@ -331,13 +338,14 @@ class HfArgumentParserTest(unittest.TestCase):
expected.add_argument("--foo", type=int, required=True)
expected.add_argument(
"--required_enum",
"--required-enum",
type=make_choice_type_function(["titi", "toto"]),
choices=["titi", "toto"],
required=True,
)
expected.add_argument("--opt", type=string_to_bool, default=None)
expected.add_argument("--baz", default="toto", type=str, help="help message")
expected.add_argument("--foo_str", nargs="+", default=["Hallo", "Bonjour", "Hello"], type=str)
expected.add_argument("--foo_str", "--foo-str", nargs="+", default=["Hallo", "Bonjour", "Hello"], type=str)
self.argparsersEqual(parser, expected)
def test_parse_dict(self):