From c9486fd0f515c38b0a525ceb5348c4b8bf2d4d9c Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Wed, 30 Jun 2021 07:57:05 -0400 Subject: [PATCH] Fix default bool in argparser (#12424) * Fix default bool in argparser * Add more to test --- src/transformers/hf_argparser.py | 4 ++-- tests/test_hf_argparser.py | 6 +++++- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/src/transformers/hf_argparser.py b/src/transformers/hf_argparser.py index 1763622504..b6f23ec4e2 100644 --- a/src/transformers/hf_argparser.py +++ b/src/transformers/hf_argparser.py @@ -112,8 +112,8 @@ class HfArgumentParser(ArgumentParser): # Hack because type=bool in argparse does not behave as we want. kwargs["type"] = string_to_bool if field.type is bool or (field.default is not None and field.default is not dataclasses.MISSING): - # Default value is True if we have no default when of type bool. - default = True if field.default is dataclasses.MISSING else field.default + # Default value is False if we have no default when of type bool. + default = False if field.default is dataclasses.MISSING else field.default # This is the value that will get picked if we don't include --field_name in any way kwargs["default"] = default # This tells argparse we accept 0 or 1 value after --field_name diff --git a/tests/test_hf_argparser.py b/tests/test_hf_argparser.py index 787990b866..44a52035dd 100644 --- a/tests/test_hf_argparser.py +++ b/tests/test_hf_argparser.py @@ -106,9 +106,13 @@ class HfArgumentParserTest(unittest.TestCase): expected.add_argument("--foo", type=int, required=True) expected.add_argument("--bar", type=float, required=True) expected.add_argument("--baz", type=str, required=True) - expected.add_argument("--flag", type=string_to_bool, default=True, const=True, nargs="?") + expected.add_argument("--flag", type=string_to_bool, default=False, const=True, nargs="?") self.argparsersEqual(parser, expected) + args = ["--foo", "1", "--baz", "quux", "--bar", "0.5"] + (example,) = parser.parse_args_into_dataclasses(args, look_for_args_file=False) + self.assertFalse(example.flag) + def test_with_default(self): parser = HfArgumentParser(WithDefaultExample)