From 838dc06ff5a438159ac25f531d622e8f344476f5 Mon Sep 17 00:00:00 2001 From: Suraj Patil Date: Fri, 31 Jul 2020 14:14:23 +0530 Subject: [PATCH] parse arguments from dict (#4869) * add parse_dict to parse arguments from dict * add unit test for parse_dict --- src/transformers/hf_argparser.py | 13 +++++++++++++ tests/test_hf_argparser.py | 14 ++++++++++++++ 2 files changed, 27 insertions(+) diff --git a/src/transformers/hf_argparser.py b/src/transformers/hf_argparser.py index 7d3e2d02e5..6c4e3f204b 100644 --- a/src/transformers/hf_argparser.py +++ b/src/transformers/hf_argparser.py @@ -158,3 +158,16 @@ class HfArgumentParser(ArgumentParser): obj = dtype(**inputs) outputs.append(obj) return (*outputs,) + + def parse_dict(self, args: dict) -> Tuple[DataClass, ...]: + """ + Alternative helper method that does not use `argparse` at all, + instead uses a dict and populating the dataclass types. + """ + outputs = [] + for dtype in self.dataclass_types: + keys = {f.name for f in dataclasses.fields(dtype)} + inputs = {k: v for k, v in args.items() if k in keys} + obj = dtype(**inputs) + outputs.append(obj) + return (*outputs,) diff --git a/tests/test_hf_argparser.py b/tests/test_hf_argparser.py index a3bda37a55..3c219d0b6f 100644 --- a/tests/test_hf_argparser.py +++ b/tests/test_hf_argparser.py @@ -152,6 +152,20 @@ class HfArgumentParserTest(unittest.TestCase): args = parser.parse_args("--foo 12 --bar 3.14 --baz 42 --ces a b c --des 1 2 3".split()) self.assertEqual(args, Namespace(foo=12, bar=3.14, baz="42", ces=["a", "b", "c"], des=[1, 2, 3])) + def test_parse_dict(self): + parser = HfArgumentParser(BasicExample) + + args_dict = { + "foo": 12, + "bar": 3.14, + "baz": "42", + "flag": True, + } + + parsed_args = parser.parse_dict(args_dict)[0] + args = BasicExample(**args_dict) + self.assertEqual(parsed_args, args) + def test_integration_training_args(self): parser = HfArgumentParser(TrainingArguments) self.assertIsNotNone(parser)