From 3f1714f8a79b19188699a4a21d2039df6078e30e Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Mon, 15 Mar 2021 09:27:55 -0400 Subject: [PATCH] Adding required flags to non-default arguments in hf_argparser (#10688) * Adding required flags to non-default arguments. Signed-off-by: Adam Pocock * make style fix. * Update src/transformers/hf_argparser.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --- src/transformers/hf_argparser.py | 4 ++++ tests/test_hf_argparser.py | 19 +++++++++++++++++++ 2 files changed, 23 insertions(+) diff --git a/src/transformers/hf_argparser.py b/src/transformers/hf_argparser.py index 305eed9c61..cb0a5675fa 100644 --- a/src/transformers/hf_argparser.py +++ b/src/transformers/hf_argparser.py @@ -99,6 +99,8 @@ class HfArgumentParser(ArgumentParser): kwargs["type"] = type(kwargs["choices"][0]) if field.default is not dataclasses.MISSING: kwargs["default"] = field.default + else: + kwargs["required"] = True elif field.type is bool or field.type == Optional[bool]: if field.default is True: self.add_argument(f"--no_{field.name}", action="store_false", dest=field.name, **kwargs) @@ -124,6 +126,8 @@ class HfArgumentParser(ArgumentParser): ), "{} cannot be a List of mixed types".format(field.name) if field.default_factory is not dataclasses.MISSING: kwargs["default"] = field.default_factory() + elif field.default is dataclasses.MISSING: + kwargs["required"] = True else: kwargs["type"] = field.type if field.default is not dataclasses.MISSING: diff --git a/tests/test_hf_argparser.py b/tests/test_hf_argparser.py index 22493a23b0..787990b866 100644 --- a/tests/test_hf_argparser.py +++ b/tests/test_hf_argparser.py @@ -78,6 +78,16 @@ class ListExample: foo_float: List[float] = list_field(default=[0.1, 0.2, 0.3]) +@dataclass +class RequiredExample: + required_list: List[int] = field() + required_str: str = field() + required_enum: BasicEnum = field() + + def __post_init__(self): + self.required_enum = BasicEnum(self.required_enum) + + class HfArgumentParserTest(unittest.TestCase): def argparsersEqual(self, a: argparse.ArgumentParser, b: argparse.ArgumentParser) -> bool: """ @@ -186,6 +196,15 @@ class HfArgumentParserTest(unittest.TestCase): args = parser.parse_args("--foo 12 --bar 3.14 --baz 42 --ces a b c --des 1 2 3".split()) self.assertEqual(args, Namespace(foo=12, bar=3.14, baz="42", ces=["a", "b", "c"], des=[1, 2, 3])) + def test_with_required(self): + parser = HfArgumentParser(RequiredExample) + + expected = argparse.ArgumentParser() + expected.add_argument("--required_list", nargs="+", type=int, required=True) + expected.add_argument("--required_str", type=str, required=True) + expected.add_argument("--required_enum", type=str, choices=["titi", "toto"], required=True) + self.argparsersEqual(parser, expected) + def test_parse_dict(self): parser = HfArgumentParser(BasicExample)