Add an option to HfArgumentParser.parse_{dict,json_file} to raise an Exception when there are extra keys (#18692)
* Update parser to track unneeded keys, off by default
* Fix formatting
* Fix docstrings and defaults in HfArgparser
* Fix formatting
This commit is contained in:
@@ -234,29 +234,60 @@ class HfArgumentParser(ArgumentParser):
|
|||||||
|
|
||||||
return (*outputs,)
|
return (*outputs,)
|
||||||
|
|
||||||
def parse_json_file(self, json_file: str) -> Tuple[DataClass, ...]:
|
def parse_json_file(self, json_file: str, allow_extra_keys: bool = False) -> Tuple[DataClass, ...]:
|
||||||
"""
|
"""
|
||||||
Alternative helper method that does not use `argparse` at all, instead loading a json file and populating the
|
Alternative helper method that does not use `argparse` at all, instead loading a json file and populating the
|
||||||
dataclass types.
|
dataclass types.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
json_file (`str` or `os.PathLike`):
|
||||||
|
File name of the json file to parse
|
||||||
|
allow_extra_keys (`bool`, *optional*, defaults to `False`):
|
||||||
|
Defaults to False. If False, will raise an exception if the json file contains keys that are not
|
||||||
|
parsed.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple consisting of:
|
||||||
|
|
||||||
|
- the dataclass instances in the same order as they were passed to the initializer.
|
||||||
"""
|
"""
|
||||||
data = json.loads(Path(json_file).read_text())
|
data = json.loads(Path(json_file).read_text())
|
||||||
|
unused_keys = set(data.keys())
|
||||||
outputs = []
|
outputs = []
|
||||||
for dtype in self.dataclass_types:
|
for dtype in self.dataclass_types:
|
||||||
keys = {f.name for f in dataclasses.fields(dtype) if f.init}
|
keys = {f.name for f in dataclasses.fields(dtype) if f.init}
|
||||||
inputs = {k: v for k, v in data.items() if k in keys}
|
inputs = {k: v for k, v in data.items() if k in keys}
|
||||||
|
unused_keys.difference_update(inputs.keys())
|
||||||
obj = dtype(**inputs)
|
obj = dtype(**inputs)
|
||||||
outputs.append(obj)
|
outputs.append(obj)
|
||||||
return (*outputs,)
|
if not allow_extra_keys and unused_keys:
|
||||||
|
raise ValueError(f"Some keys are not used by the HfArgumentParser: {sorted(unused_keys)}")
|
||||||
|
return tuple(outputs)
|
||||||
|
|
||||||
def parse_dict(self, args: dict) -> Tuple[DataClass, ...]:
|
def parse_dict(self, args: Dict[str, Any], allow_extra_keys: bool = False) -> Tuple[DataClass, ...]:
|
||||||
"""
|
"""
|
||||||
Alternative helper method that does not use `argparse` at all, instead uses a dict and populating the dataclass
|
Alternative helper method that does not use `argparse` at all, instead uses a dict and populating the dataclass
|
||||||
types.
|
types.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
args (`dict`):
|
||||||
|
dict containing config values
|
||||||
|
allow_extra_keys (`bool`, *optional*, defaults to `False`):
|
||||||
|
Defaults to False. If False, will raise an exception if the dict contains keys that are not parsed.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple consisting of:
|
||||||
|
|
||||||
|
- the dataclass instances in the same order as they were passed to the initializer.
|
||||||
"""
|
"""
|
||||||
|
unused_keys = set(args.keys())
|
||||||
outputs = []
|
outputs = []
|
||||||
for dtype in self.dataclass_types:
|
for dtype in self.dataclass_types:
|
||||||
keys = {f.name for f in dataclasses.fields(dtype) if f.init}
|
keys = {f.name for f in dataclasses.fields(dtype) if f.init}
|
||||||
inputs = {k: v for k, v in args.items() if k in keys}
|
inputs = {k: v for k, v in args.items() if k in keys}
|
||||||
|
unused_keys.difference_update(inputs.keys())
|
||||||
obj = dtype(**inputs)
|
obj = dtype(**inputs)
|
||||||
outputs.append(obj)
|
outputs.append(obj)
|
||||||
return (*outputs,)
|
if not allow_extra_keys and unused_keys:
|
||||||
|
raise ValueError(f"Some keys are not used by the HfArgumentParser: {sorted(unused_keys)}")
|
||||||
|
return tuple(outputs)
|
||||||
|
|||||||
@@ -245,6 +245,19 @@ class HfArgumentParserTest(unittest.TestCase):
|
|||||||
args = BasicExample(**args_dict)
|
args = BasicExample(**args_dict)
|
||||||
self.assertEqual(parsed_args, args)
|
self.assertEqual(parsed_args, args)
|
||||||
|
|
||||||
|
def test_parse_dict_extra_key(self):
|
||||||
|
parser = HfArgumentParser(BasicExample)
|
||||||
|
|
||||||
|
args_dict = {
|
||||||
|
"foo": 12,
|
||||||
|
"bar": 3.14,
|
||||||
|
"baz": "42",
|
||||||
|
"flag": True,
|
||||||
|
"extra": 42,
|
||||||
|
}
|
||||||
|
|
||||||
|
self.assertRaises(ValueError, parser.parse_dict, args_dict, allow_extra_keys=False)
|
||||||
|
|
||||||
def test_integration_training_args(self):
|
def test_integration_training_args(self):
|
||||||
parser = HfArgumentParser(TrainingArguments)
|
parser = HfArgumentParser(TrainingArguments)
|
||||||
self.assertIsNotNone(parser)
|
self.assertIsNotNone(parser)
|
||||||
|
|||||||
Reference in New Issue
Block a user