Merge pull request #3934 from huggingface/examples_args_from_files

[qol] example scripts: parse args from .args file or JSON
This commit is contained in:
Julien Chaumond
2020-04-30 22:40:13 -04:00
committed by GitHub
parent f39217a5ec
commit b8686174be
2 changed files with 47 additions and 9 deletions

View File

@@ -1,6 +1,9 @@
import dataclasses
import json
import sys
from argparse import ArgumentParser
from enum import Enum
from pathlib import Path
from typing import Any, Iterable, NewType, Tuple, Union
@@ -8,6 +11,10 @@ DataClass = NewType("DataClass", Any)
DataClassType = NewType("DataClassType", Any)
def trim_suffix(s: str, suffix: str):
    """Return `s` without its trailing `suffix`; return `s` unchanged if it does not end with it."""
    # An empty suffix is handled explicitly: slicing with a zero-length
    # negative offset would otherwise produce an empty string.
    if suffix and s.endswith(suffix):
        return s[: len(s) - len(suffix)]
    return s
class HfArgumentParser(ArgumentParser):
"""
This subclass of `argparse.ArgumentParser` uses type hints on dataclasses
@@ -70,7 +77,9 @@ class HfArgumentParser(ArgumentParser):
kwargs["required"] = True
self.add_argument(field_name, **kwargs)
def parse_args_into_dataclasses(self, args=None, return_remaining_strings=False) -> Tuple[DataClass, ...]:
def parse_args_into_dataclasses(
self, args=None, return_remaining_strings=False, look_for_args_file=True
) -> Tuple[DataClass, ...]:
"""
Parse command-line args into instances of the specified dataclass types.
@@ -84,6 +93,10 @@ class HfArgumentParser(ArgumentParser):
(same as argparse.ArgumentParser)
return_remaining_strings:
If true, also return a list of remaining argument strings.
look_for_args_file:
If true, will look for a ".args" file with the same base name
as the entry point script for this process, and will prepend its
content (if any) to the command line args, so that values given
explicitly on the command line take precedence.
Returns:
Tuple consisting of:
@@ -95,6 +108,14 @@ class HfArgumentParser(ArgumentParser):
- The potential list of remaining argument strings.
(same as argparse.ArgumentParser.parse_known_args)
"""
if look_for_args_file and len(sys.argv):
basename = trim_suffix(sys.argv[0], ".py")
args_file = Path(f"{basename}.args")
if args_file.exists():
fargs = args_file.read_text().split()
args = fargs + args if args is not None else fargs + sys.argv[1:]
# in case of duplicate arguments the last one has precedence, so the
# file args go first and the command-line args after them, overriding them.
namespace, remaining_args = self.parse_known_args(args=args)
outputs = []
for dtype in self.dataclass_types:
@@ -111,3 +132,17 @@ class HfArgumentParser(ArgumentParser):
return (*outputs, remaining_args)
else:
return (*outputs,)
def parse_json_file(self, json_file: str) -> Tuple[DataClass, ...]:
    """
    Alternative helper method that does not use `argparse` at all,
    instead loading a json file and populating the dataclass types.

    Args:
        json_file:
            Path to a JSON file whose top-level value is an object
            mapping field names to values.

    Returns:
        Tuple with one instance per entry of `self.dataclass_types`, in the
        same order. Each instance only receives the keys that match one of
        its own init fields; any other key in the file is ignored.
    """
    # Decode explicitly as UTF-8 instead of relying on the platform's
    # locale encoding, which can mangle or reject non-ASCII JSON.
    data = json.loads(Path(json_file).read_text(encoding="utf-8"))
    outputs = []
    for dtype in self.dataclass_types:
        # Fields declared with init=False are not constructor parameters;
        # passing them would raise TypeError, so they are left out.
        keys = {f.name for f in dataclasses.fields(dtype) if f.init}
        inputs = {k: v for k, v in data.items() if k in keys}
        outputs.append(dtype(**inputs))
    return tuple(outputs)