Merge pull request #3934 from huggingface/examples_args_from_files

[qol] example scripts: parse args from .args file or JSON
This commit is contained in:
Julien Chaumond
2020-04-30 22:40:13 -04:00
committed by GitHub
parent f39217a5ec
commit b8686174be
2 changed files with 47 additions and 9 deletions

View File

@@ -19,18 +19,15 @@
import dataclasses import dataclasses
import logging import logging
import os import os
import sys
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Dict, Optional from typing import Dict, Optional
import numpy as np import numpy as np
from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer, EvalPrediction, GlueDataset
from transformers import GlueDataTrainingArguments as DataTrainingArguments
from transformers import ( from transformers import (
AutoConfig,
AutoModelForSequenceClassification,
AutoTokenizer,
EvalPrediction,
GlueDataset,
GlueDataTrainingArguments,
HfArgumentParser, HfArgumentParser,
Trainer, Trainer,
TrainingArguments, TrainingArguments,
@@ -69,8 +66,14 @@ def main():
# or by passing the --help flag to this script. # or by passing the --help flag to this script.
# We now keep distinct sets of args, for a cleaner separation of concerns. # We now keep distinct sets of args, for a cleaner separation of concerns.
parser = HfArgumentParser((ModelArguments, GlueDataTrainingArguments, TrainingArguments)) parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
# If we pass only one argument to the script and it's the path to a json file,
# let's parse it to get our arguments.
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
if ( if (
os.path.exists(training_args.output_dir) os.path.exists(training_args.output_dir)

View File

@@ -1,6 +1,9 @@
import dataclasses import dataclasses
import json
import sys
from argparse import ArgumentParser from argparse import ArgumentParser
from enum import Enum from enum import Enum
from pathlib import Path
from typing import Any, Iterable, NewType, Tuple, Union from typing import Any, Iterable, NewType, Tuple, Union
@@ -8,6 +11,10 @@ DataClass = NewType("DataClass", Any)
DataClassType = NewType("DataClassType", Any) DataClassType = NewType("DataClassType", Any)
def trim_suffix(s: str, suffix: str):
return s if not s.endswith(suffix) or len(suffix) == 0 else s[: -len(suffix)]
class HfArgumentParser(ArgumentParser): class HfArgumentParser(ArgumentParser):
""" """
This subclass of `argparse.ArgumentParser` uses type hints on dataclasses This subclass of `argparse.ArgumentParser` uses type hints on dataclasses
@@ -70,7 +77,9 @@ class HfArgumentParser(ArgumentParser):
kwargs["required"] = True kwargs["required"] = True
self.add_argument(field_name, **kwargs) self.add_argument(field_name, **kwargs)
def parse_args_into_dataclasses(self, args=None, return_remaining_strings=False) -> Tuple[DataClass, ...]: def parse_args_into_dataclasses(
self, args=None, return_remaining_strings=False, look_for_args_file=True
) -> Tuple[DataClass, ...]:
""" """
Parse command-line args into instances of the specified dataclass types. Parse command-line args into instances of the specified dataclass types.
@@ -84,6 +93,10 @@ class HfArgumentParser(ArgumentParser):
(same as argparse.ArgumentParser) (same as argparse.ArgumentParser)
return_remaining_strings: return_remaining_strings:
If true, also return a list of remaining argument strings. If true, also return a list of remaining argument strings.
look_for_args_file:
If true, will look for a ".args" file with the same base name
as the entry point script for this process, and will append its
potential content to the command line args.
Returns: Returns:
Tuple consisting of: Tuple consisting of:
@@ -95,6 +108,14 @@ class HfArgumentParser(ArgumentParser):
- The potential list of remaining argument strings. - The potential list of remaining argument strings.
(same as argparse.ArgumentParser.parse_known_args) (same as argparse.ArgumentParser.parse_known_args)
""" """
if look_for_args_file and len(sys.argv):
basename = trim_suffix(sys.argv[0], ".py")
args_file = Path(f"{basename}.args")
if args_file.exists():
fargs = args_file.read_text().split()
args = fargs + args if args is not None else fargs + sys.argv[1:]
# in case of duplicate arguments the first one has precedence
# so we append rather than prepend.
namespace, remaining_args = self.parse_known_args(args=args) namespace, remaining_args = self.parse_known_args(args=args)
outputs = [] outputs = []
for dtype in self.dataclass_types: for dtype in self.dataclass_types:
@@ -111,3 +132,17 @@ class HfArgumentParser(ArgumentParser):
return (*outputs, remaining_args) return (*outputs, remaining_args)
else: else:
return (*outputs,) return (*outputs,)
def parse_json_file(self, json_file: str) -> Tuple[DataClass, ...]:
"""
Alternative helper method that does not use `argparse` at all,
instead loading a json file and populating the dataclass types.
"""
data = json.loads(Path(json_file).read_text())
outputs = []
for dtype in self.dataclass_types:
keys = {f.name for f in dataclasses.fields(dtype)}
inputs = {k: v for k, v in data.items() if k in keys}
obj = dtype(**inputs)
outputs.append(obj)
return (*outputs,)