Merge pull request #3934 from huggingface/examples_args_from_files
[qol] example scripts: parse args from .args file or JSON
This commit is contained in:
@@ -1,6 +1,9 @@
|
||||
import dataclasses
|
||||
import json
|
||||
import sys
|
||||
from argparse import ArgumentParser
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Iterable, NewType, Tuple, Union
|
||||
|
||||
|
||||
@@ -8,6 +11,10 @@ DataClass = NewType("DataClass", Any)
|
||||
DataClassType = NewType("DataClassType", Any)
|
||||
|
||||
|
||||
def trim_suffix(s: str, suffix: str):
|
||||
return s if not s.endswith(suffix) or len(suffix) == 0 else s[: -len(suffix)]
|
||||
|
||||
|
||||
class HfArgumentParser(ArgumentParser):
|
||||
"""
|
||||
This subclass of `argparse.ArgumentParser` uses type hints on dataclasses
|
||||
@@ -70,7 +77,9 @@ class HfArgumentParser(ArgumentParser):
|
||||
kwargs["required"] = True
|
||||
self.add_argument(field_name, **kwargs)
|
||||
|
||||
def parse_args_into_dataclasses(self, args=None, return_remaining_strings=False) -> Tuple[DataClass, ...]:
|
||||
def parse_args_into_dataclasses(
|
||||
self, args=None, return_remaining_strings=False, look_for_args_file=True
|
||||
) -> Tuple[DataClass, ...]:
|
||||
"""
|
||||
Parse command-line args into instances of the specified dataclass types.
|
||||
|
||||
@@ -84,6 +93,10 @@ class HfArgumentParser(ArgumentParser):
|
||||
(same as argparse.ArgumentParser)
|
||||
return_remaining_strings:
|
||||
If true, also return a list of remaining argument strings.
|
||||
look_for_args_file:
|
||||
If true, will look for a ".args" file with the same base name
|
||||
as the entry point script for this process, and will append its
|
||||
potential content to the command line args.
|
||||
|
||||
Returns:
|
||||
Tuple consisting of:
|
||||
@@ -95,6 +108,14 @@ class HfArgumentParser(ArgumentParser):
|
||||
- The potential list of remaining argument strings.
|
||||
(same as argparse.ArgumentParser.parse_known_args)
|
||||
"""
|
||||
if look_for_args_file and len(sys.argv):
|
||||
basename = trim_suffix(sys.argv[0], ".py")
|
||||
args_file = Path(f"{basename}.args")
|
||||
if args_file.exists():
|
||||
fargs = args_file.read_text().split()
|
||||
args = fargs + args if args is not None else fargs + sys.argv[1:]
|
||||
# in case of duplicate arguments the first one has precedence
|
||||
# so we append rather than prepend.
|
||||
namespace, remaining_args = self.parse_known_args(args=args)
|
||||
outputs = []
|
||||
for dtype in self.dataclass_types:
|
||||
@@ -111,3 +132,17 @@ class HfArgumentParser(ArgumentParser):
|
||||
return (*outputs, remaining_args)
|
||||
else:
|
||||
return (*outputs,)
|
||||
|
||||
def parse_json_file(self, json_file: str) -> Tuple[DataClass, ...]:
|
||||
"""
|
||||
Alternative helper method that does not use `argparse` at all,
|
||||
instead loading a json file and populating the dataclass types.
|
||||
"""
|
||||
data = json.loads(Path(json_file).read_text())
|
||||
outputs = []
|
||||
for dtype in self.dataclass_types:
|
||||
keys = {f.name for f in dataclasses.fields(dtype)}
|
||||
inputs = {k: v for k, v in data.items() if k in keys}
|
||||
obj = dtype(**inputs)
|
||||
outputs.append(obj)
|
||||
return (*outputs,)
|
||||
|
||||
Reference in New Issue
Block a user