Adding required flags to non-default arguments in hf_argparser (#10688)
* Adding required flags to non-default arguments. Signed-off-by: Adam Pocock <adam.pocock@oracle.com> * make style fix. * Update src/transformers/hf_argparser.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -78,6 +78,16 @@ class ListExample:
|
||||
foo_float: List[float] = list_field(default=[0.1, 0.2, 0.3])
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequiredExample:
|
||||
required_list: List[int] = field()
|
||||
required_str: str = field()
|
||||
required_enum: BasicEnum = field()
|
||||
|
||||
def __post_init__(self):
|
||||
self.required_enum = BasicEnum(self.required_enum)
|
||||
|
||||
|
||||
class HfArgumentParserTest(unittest.TestCase):
|
||||
def argparsersEqual(self, a: argparse.ArgumentParser, b: argparse.ArgumentParser) -> bool:
|
||||
"""
|
||||
@@ -186,6 +196,15 @@ class HfArgumentParserTest(unittest.TestCase):
|
||||
args = parser.parse_args("--foo 12 --bar 3.14 --baz 42 --ces a b c --des 1 2 3".split())
|
||||
self.assertEqual(args, Namespace(foo=12, bar=3.14, baz="42", ces=["a", "b", "c"], des=[1, 2, 3]))
|
||||
|
||||
def test_with_required(self):
|
||||
parser = HfArgumentParser(RequiredExample)
|
||||
|
||||
expected = argparse.ArgumentParser()
|
||||
expected.add_argument("--required_list", nargs="+", type=int, required=True)
|
||||
expected.add_argument("--required_str", type=str, required=True)
|
||||
expected.add_argument("--required_enum", type=str, choices=["titi", "toto"], required=True)
|
||||
self.argparsersEqual(parser, expected)
|
||||
|
||||
def test_parse_dict(self):
|
||||
parser = HfArgumentParser(BasicExample)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user