parse arguments from dict (#4869)
* add parse_dict to parse arguments from dict * add unit test for parse_dict
This commit is contained in:
@@ -158,3 +158,16 @@ class HfArgumentParser(ArgumentParser):
|
|||||||
obj = dtype(**inputs)
|
obj = dtype(**inputs)
|
||||||
outputs.append(obj)
|
outputs.append(obj)
|
||||||
return (*outputs,)
|
return (*outputs,)
|
||||||
|
|
||||||
|
def parse_dict(self, args: dict) -> Tuple[DataClass, ...]:
|
||||||
|
"""
|
||||||
|
Alternative helper method that does not use `argparse` at all,
|
||||||
|
instead uses a dict and populating the dataclass types.
|
||||||
|
"""
|
||||||
|
outputs = []
|
||||||
|
for dtype in self.dataclass_types:
|
||||||
|
keys = {f.name for f in dataclasses.fields(dtype)}
|
||||||
|
inputs = {k: v for k, v in args.items() if k in keys}
|
||||||
|
obj = dtype(**inputs)
|
||||||
|
outputs.append(obj)
|
||||||
|
return (*outputs,)
|
||||||
|
|||||||
@@ -152,6 +152,20 @@ class HfArgumentParserTest(unittest.TestCase):
|
|||||||
args = parser.parse_args("--foo 12 --bar 3.14 --baz 42 --ces a b c --des 1 2 3".split())
|
args = parser.parse_args("--foo 12 --bar 3.14 --baz 42 --ces a b c --des 1 2 3".split())
|
||||||
self.assertEqual(args, Namespace(foo=12, bar=3.14, baz="42", ces=["a", "b", "c"], des=[1, 2, 3]))
|
self.assertEqual(args, Namespace(foo=12, bar=3.14, baz="42", ces=["a", "b", "c"], des=[1, 2, 3]))
|
||||||
|
|
||||||
|
def test_parse_dict(self):
|
||||||
|
parser = HfArgumentParser(BasicExample)
|
||||||
|
|
||||||
|
args_dict = {
|
||||||
|
"foo": 12,
|
||||||
|
"bar": 3.14,
|
||||||
|
"baz": "42",
|
||||||
|
"flag": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
parsed_args = parser.parse_dict(args_dict)[0]
|
||||||
|
args = BasicExample(**args_dict)
|
||||||
|
self.assertEqual(parsed_args, args)
|
||||||
|
|
||||||
def test_integration_training_args(self):
|
def test_integration_training_args(self):
|
||||||
parser = HfArgumentParser(TrainingArguments)
|
parser = HfArgumentParser(TrainingArguments)
|
||||||
self.assertIsNotNone(parser)
|
self.assertIsNotNone(parser)
|
||||||
|
|||||||
Reference in New Issue
Block a user