From b8686174be75220d2c26a961597a39ef4921b616 Mon Sep 17 00:00:00 2001 From: Julien Chaumond Date: Thu, 30 Apr 2020 22:40:13 -0400 Subject: [PATCH] Merge pull request #3934 from huggingface/examples_args_from_files [qol] example scripts: parse args from .args file or JSON --- examples/run_glue.py | 19 +++++++++------- src/transformers/hf_argparser.py | 37 +++++++++++++++++++++++++++++++- 2 files changed, 47 insertions(+), 9 deletions(-) diff --git a/examples/run_glue.py b/examples/run_glue.py index bbb4be1850..e58eb01211 100644 --- a/examples/run_glue.py +++ b/examples/run_glue.py @@ -19,18 +19,15 @@ import dataclasses import logging import os +import sys from dataclasses import dataclass, field from typing import Dict, Optional import numpy as np +from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer, EvalPrediction, GlueDataset +from transformers import GlueDataTrainingArguments as DataTrainingArguments from transformers import ( - AutoConfig, - AutoModelForSequenceClassification, - AutoTokenizer, - EvalPrediction, - GlueDataset, - GlueDataTrainingArguments, HfArgumentParser, Trainer, TrainingArguments, @@ -69,8 +66,14 @@ def main(): # or by passing the --help flag to this script. # We now keep distinct sets of args, for a cleaner separation of concerns. - parser = HfArgumentParser((ModelArguments, GlueDataTrainingArguments, TrainingArguments)) - model_args, data_args, training_args = parser.parse_args_into_dataclasses() + parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) + + if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + # If we pass only one argument to the script and it's the path to a json file, + # let's parse it to get our arguments. + model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) + else: + model_args, data_args, training_args = parser.parse_args_into_dataclasses() if ( os.path.exists(training_args.output_dir) diff --git a/src/transformers/hf_argparser.py b/src/transformers/hf_argparser.py index 35dc83d7ca..8bb0ddd57d 100644 --- a/src/transformers/hf_argparser.py +++ b/src/transformers/hf_argparser.py @@ -1,6 +1,9 @@ import dataclasses +import json +import sys from argparse import ArgumentParser from enum import Enum +from pathlib import Path from typing import Any, Iterable, NewType, Tuple, Union @@ -8,6 +11,10 @@ DataClass = NewType("DataClass", Any) DataClassType = NewType("DataClassType", Any) +def trim_suffix(s: str, suffix: str): + return s if not s.endswith(suffix) or len(suffix) == 0 else s[: -len(suffix)] + + class HfArgumentParser(ArgumentParser): """ This subclass of `argparse.ArgumentParser` uses type hints on dataclasses @@ -70,7 +77,9 @@ class HfArgumentParser(ArgumentParser): kwargs["required"] = True self.add_argument(field_name, **kwargs) - def parse_args_into_dataclasses(self, args=None, return_remaining_strings=False) -> Tuple[DataClass, ...]: + def parse_args_into_dataclasses( + self, args=None, return_remaining_strings=False, look_for_args_file=True + ) -> Tuple[DataClass, ...]: """ Parse command-line args into instances of the specified dataclass types. @@ -84,6 +93,10 @@ class HfArgumentParser(ArgumentParser): (same as argparse.ArgumentParser) return_remaining_strings: If true, also return a list of remaining argument strings. + look_for_args_file: + If true, will look for a ".args" file with the same base name + as the entry point script for this process, and will append its + potential content to the command line args. Returns: Tuple consisting of: @@ -95,6 +108,14 @@ class HfArgumentParser(ArgumentParser): - The potential list of remaining argument strings. (same as argparse.ArgumentParser.parse_known_args) """ + if look_for_args_file and len(sys.argv): + basename = trim_suffix(sys.argv[0], ".py") + args_file = Path(f"{basename}.args") + if args_file.exists(): + fargs = args_file.read_text().split() + args = fargs + args if args is not None else fargs + sys.argv[1:] + # in case of duplicate arguments the first one has precedence + # so we append rather than prepend. namespace, remaining_args = self.parse_known_args(args=args) outputs = [] for dtype in self.dataclass_types: @@ -111,3 +132,17 @@ class HfArgumentParser(ArgumentParser): return (*outputs, remaining_args) else: return (*outputs,) + + def parse_json_file(self, json_file: str) -> Tuple[DataClass, ...]: + """ + Alternative helper method that does not use `argparse` at all, + instead loading a json file and populating the dataclass types. + """ + data = json.loads(Path(json_file).read_text()) + outputs = [] + for dtype in self.dataclass_types: + keys = {f.name for f in dataclasses.fields(dtype)} + inputs = {k: v for k, v in data.items() if k in keys} + obj = dtype(**inputs) + outputs.append(obj) + return (*outputs,)