diff --git a/src/transformers/hf_argparser.py b/src/transformers/hf_argparser.py index 4b5548fffb..d03ff7004f 100644 --- a/src/transformers/hf_argparser.py +++ b/src/transformers/hf_argparser.py @@ -138,7 +138,14 @@ class HfArgumentParser(ArgumentParser): @staticmethod def _parse_dataclass_field(parser: ArgumentParser, field: dataclasses.Field): - field_name = f"--{field.name}" + # Long-option strings are conventionlly separated by hyphens rather + # than underscores, e.g., "--long-format" rather than "--long_format". + # Argparse converts hyphens to underscores so that the destination + # string is a valid attribute name. Hf_argparser should do the same. + long_options = [f"--{field.name}"] + if "_" in field.name: + long_options.append(f"--{field.name.replace('_', '-')}") + kwargs = field.metadata.copy() # field.metadata is not used at all by Data Classes, # it is provided as a third-party extension mechanism. @@ -198,11 +205,11 @@ class HfArgumentParser(ArgumentParser): if field.type is bool or (field.default is not None and field.default is not dataclasses.MISSING): # Default value is False if we have no default when of type bool. default = False if field.default is dataclasses.MISSING else field.default - # This is the value that will get picked if we don't include --field_name in any way + # This is the value that will get picked if we don't include --{field.name} in any way kwargs["default"] = default - # This tells argparse we accept 0 or 1 value after --field_name + # This tells argparse we accept 0 or 1 value after --{field.name} kwargs["nargs"] = "?" - # This is the value that will get picked if we do --field_name (without value) + # This is the value that will get picked if we do --{field.name} (without value) kwargs["const"] = True elif isclass(origin_type) and issubclass(origin_type, list): kwargs["type"] = field.type.__args__[0] @@ -219,7 +226,7 @@ class HfArgumentParser(ArgumentParser): kwargs["default"] = field.default_factory() else: kwargs["required"] = True - parser.add_argument(field_name, *aliases, **kwargs) + parser.add_argument(*long_options, *aliases, **kwargs) # Add a complement `no_*` argument for a boolean field AFTER the initial field has already been added. # Order is important for arguments with the same destination! @@ -227,7 +234,13 @@ class HfArgumentParser(ArgumentParser): # here and we do not need those changes/additional keys. if field.default is True and (field.type is bool or field.type == Optional[bool]): bool_kwargs["default"] = False - parser.add_argument(f"--no_{field.name}", action="store_false", dest=field.name, **bool_kwargs) + parser.add_argument( + f"--no_{field.name}", + f"--no-{field.name.replace('_', '-')}", + action="store_false", + dest=field.name, + **bool_kwargs, + ) def _add_dataclass_arguments(self, dtype: DataClassType): if hasattr(dtype, "_argument_group_name"): diff --git a/tests/utils/test_hf_argparser.py b/tests/utils/test_hf_argparser.py index e075cdae16..08c730f734 100644 --- a/tests/utils/test_hf_argparser.py +++ b/tests/utils/test_hf_argparser.py @@ -189,7 +189,7 @@ class HfArgumentParserTest(unittest.TestCase): expected.add_argument("--baz", type=string_to_bool, default=True, const=True, nargs="?") # A boolean no_* argument always has to come after its "default: True" regular counter-part # and its default must be set to False - expected.add_argument("--no_baz", action="store_false", default=False, dest="baz") + expected.add_argument("--no_baz", "--no-baz", action="store_false", default=False, dest="baz") expected.add_argument("--opt", type=string_to_bool, default=None) dataclass_types = [WithDefaultBoolExample] @@ -206,6 +206,9 @@ class HfArgumentParserTest(unittest.TestCase): args = parser.parse_args(["--foo", "--no_baz"]) self.assertEqual(args, Namespace(foo=True, baz=False, opt=None)) + args = parser.parse_args(["--foo", "--no-baz"]) + self.assertEqual(args, Namespace(foo=True, baz=False, opt=None)) + args = parser.parse_args(["--foo", "--baz"]) self.assertEqual(args, Namespace(foo=True, baz=True, opt=None)) @@ -271,10 +274,10 @@ class HfArgumentParserTest(unittest.TestCase): parser = HfArgumentParser(ListExample) expected = argparse.ArgumentParser() - expected.add_argument("--foo_int", nargs="+", default=[], type=int) - expected.add_argument("--bar_int", nargs="+", default=[1, 2, 3], type=int) - expected.add_argument("--foo_str", nargs="+", default=["Hallo", "Bonjour", "Hello"], type=str) - expected.add_argument("--foo_float", nargs="+", default=[0.1, 0.2, 0.3], type=float) + expected.add_argument("--foo_int", "--foo-int", nargs="+", default=[], type=int) + expected.add_argument("--bar_int", "--bar-int", nargs="+", default=[1, 2, 3], type=int) + expected.add_argument("--foo_str", "--foo-str", nargs="+", default=["Hallo", "Bonjour", "Hello"], type=str) + expected.add_argument("--foo_float", "--foo-float", nargs="+", default=[0.1, 0.2, 0.3], type=float) self.argparsersEqual(parser, expected) @@ -287,6 +290,9 @@ class HfArgumentParserTest(unittest.TestCase): args = parser.parse_args("--foo_int 1 --bar_int 2 3 --foo_str a b c --foo_float 0.1 0.7".split()) self.assertEqual(args, Namespace(foo_int=[1], bar_int=[2, 3], foo_str=["a", "b", "c"], foo_float=[0.1, 0.7])) + args = parser.parse_args("--foo-int 1 --bar-int 2 3 --foo-str a b c --foo-float 0.1 0.7".split()) + self.assertEqual(args, Namespace(foo_int=[1], bar_int=[2, 3], foo_str=["a", "b", "c"], foo_float=[0.1, 0.7])) + def test_with_optional(self): expected = argparse.ArgumentParser() expected.add_argument("--foo", default=None, type=int) @@ -314,10 +320,11 @@ class HfArgumentParserTest(unittest.TestCase): parser = HfArgumentParser(RequiredExample) expected = argparse.ArgumentParser() - expected.add_argument("--required_list", nargs="+", type=int, required=True) - expected.add_argument("--required_str", type=str, required=True) + expected.add_argument("--required_list", "--required-list", nargs="+", type=int, required=True) + expected.add_argument("--required_str", "--required-str", type=str, required=True) expected.add_argument( "--required_enum", + "--required-enum", type=make_choice_type_function(["titi", "toto"]), choices=["titi", "toto"], required=True, @@ -331,13 +338,14 @@ class HfArgumentParserTest(unittest.TestCase): expected.add_argument("--foo", type=int, required=True) expected.add_argument( "--required_enum", + "--required-enum", type=make_choice_type_function(["titi", "toto"]), choices=["titi", "toto"], required=True, ) expected.add_argument("--opt", type=string_to_bool, default=None) expected.add_argument("--baz", default="toto", type=str, help="help message") - expected.add_argument("--foo_str", nargs="+", default=["Hallo", "Bonjour", "Hello"], type=str) + expected.add_argument("--foo_str", "--foo-str", nargs="+", default=["Hallo", "Bonjour", "Hello"], type=str) self.argparsersEqual(parser, expected) def test_parse_dict(self):