Fix default bool in argparser (#12424)
* Fix default bool in argparser * Add more to test
This commit is contained in:
@@ -112,8 +112,8 @@ class HfArgumentParser(ArgumentParser):
|
|||||||
# Hack because type=bool in argparse does not behave as we want.
|
# Hack because type=bool in argparse does not behave as we want.
|
||||||
kwargs["type"] = string_to_bool
|
kwargs["type"] = string_to_bool
|
||||||
if field.type is bool or (field.default is not None and field.default is not dataclasses.MISSING):
|
if field.type is bool or (field.default is not None and field.default is not dataclasses.MISSING):
|
||||||
# Default value is True if we have no default when of type bool.
|
# Default value is False if we have no default when of type bool.
|
||||||
default = True if field.default is dataclasses.MISSING else field.default
|
default = False if field.default is dataclasses.MISSING else field.default
|
||||||
# This is the value that will get picked if we don't include --field_name in any way
|
# This is the value that will get picked if we don't include --field_name in any way
|
||||||
kwargs["default"] = default
|
kwargs["default"] = default
|
||||||
# This tells argparse we accept 0 or 1 value after --field_name
|
# This tells argparse we accept 0 or 1 value after --field_name
|
||||||
|
|||||||
@@ -106,9 +106,13 @@ class HfArgumentParserTest(unittest.TestCase):
|
|||||||
expected.add_argument("--foo", type=int, required=True)
|
expected.add_argument("--foo", type=int, required=True)
|
||||||
expected.add_argument("--bar", type=float, required=True)
|
expected.add_argument("--bar", type=float, required=True)
|
||||||
expected.add_argument("--baz", type=str, required=True)
|
expected.add_argument("--baz", type=str, required=True)
|
||||||
expected.add_argument("--flag", type=string_to_bool, default=True, const=True, nargs="?")
|
expected.add_argument("--flag", type=string_to_bool, default=False, const=True, nargs="?")
|
||||||
self.argparsersEqual(parser, expected)
|
self.argparsersEqual(parser, expected)
|
||||||
|
|
||||||
|
args = ["--foo", "1", "--baz", "quux", "--bar", "0.5"]
|
||||||
|
(example,) = parser.parse_args_into_dataclasses(args, look_for_args_file=False)
|
||||||
|
self.assertFalse(example.flag)
|
||||||
|
|
||||||
def test_with_default(self):
|
def test_with_default(self):
|
||||||
parser = HfArgumentParser(WithDefaultExample)
|
parser = HfArgumentParser(WithDefaultExample)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user