Adding required flags to non-default arguments in hf_argparser (#10688)

* Adding required flags to non-default arguments.

Signed-off-by: Adam Pocock <adam.pocock@oracle.com>

* make style fix.

* Update src/transformers/hf_argparser.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Adam Pocock
2021-03-15 09:27:55 -04:00
committed by GitHub
parent 6f840990a7
commit 3f1714f8a7
2 changed files with 23 additions and 0 deletions

View File

@@ -78,6 +78,16 @@ class ListExample:
foo_float: List[float] = list_field(default=[0.1, 0.2, 0.3])
@dataclass
class RequiredExample:
required_list: List[int] = field()
required_str: str = field()
required_enum: BasicEnum = field()
def __post_init__(self):
self.required_enum = BasicEnum(self.required_enum)
class HfArgumentParserTest(unittest.TestCase):
def argparsersEqual(self, a: argparse.ArgumentParser, b: argparse.ArgumentParser) -> bool:
"""
@@ -186,6 +196,15 @@ class HfArgumentParserTest(unittest.TestCase):
args = parser.parse_args("--foo 12 --bar 3.14 --baz 42 --ces a b c --des 1 2 3".split())
self.assertEqual(args, Namespace(foo=12, bar=3.14, baz="42", ces=["a", "b", "c"], des=[1, 2, 3]))
def test_with_required(self):
parser = HfArgumentParser(RequiredExample)
expected = argparse.ArgumentParser()
expected.add_argument("--required_list", nargs="+", type=int, required=True)
expected.add_argument("--required_str", type=str, required=True)
expected.add_argument("--required_enum", type=str, choices=["titi", "toto"], required=True)
self.argparsersEqual(parser, expected)
def test_parse_dict(self):
parser = HfArgumentParser(BasicExample)