Merge pull request #3934 from huggingface/examples_args_from_files
[qol] example scripts: parse args from .args file or JSON
This commit is contained in:
@@ -19,18 +19,15 @@
|
|||||||
import dataclasses
|
import dataclasses
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import sys
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Dict, Optional
|
from typing import Dict, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer, EvalPrediction, GlueDataset
|
||||||
|
from transformers import GlueDataTrainingArguments as DataTrainingArguments
|
||||||
from transformers import (
|
from transformers import (
|
||||||
AutoConfig,
|
|
||||||
AutoModelForSequenceClassification,
|
|
||||||
AutoTokenizer,
|
|
||||||
EvalPrediction,
|
|
||||||
GlueDataset,
|
|
||||||
GlueDataTrainingArguments,
|
|
||||||
HfArgumentParser,
|
HfArgumentParser,
|
||||||
Trainer,
|
Trainer,
|
||||||
TrainingArguments,
|
TrainingArguments,
|
||||||
@@ -69,8 +66,14 @@ def main():
|
|||||||
# or by passing the --help flag to this script.
|
# or by passing the --help flag to this script.
|
||||||
# We now keep distinct sets of args, for a cleaner separation of concerns.
|
# We now keep distinct sets of args, for a cleaner separation of concerns.
|
||||||
|
|
||||||
parser = HfArgumentParser((ModelArguments, GlueDataTrainingArguments, TrainingArguments))
|
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
|
||||||
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
|
||||||
|
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
||||||
|
# If we pass only one argument to the script and it's the path to a json file,
|
||||||
|
# let's parse it to get our arguments.
|
||||||
|
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
|
||||||
|
else:
|
||||||
|
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
||||||
|
|
||||||
if (
|
if (
|
||||||
os.path.exists(training_args.output_dir)
|
os.path.exists(training_args.output_dir)
|
||||||
|
|||||||
@@ -1,6 +1,9 @@
|
|||||||
import dataclasses
|
import dataclasses
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
from argparse import ArgumentParser
|
from argparse import ArgumentParser
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
from pathlib import Path
|
||||||
from typing import Any, Iterable, NewType, Tuple, Union
|
from typing import Any, Iterable, NewType, Tuple, Union
|
||||||
|
|
||||||
|
|
||||||
@@ -8,6 +11,10 @@ DataClass = NewType("DataClass", Any)
|
|||||||
DataClassType = NewType("DataClassType", Any)
|
DataClassType = NewType("DataClassType", Any)
|
||||||
|
|
||||||
|
|
||||||
|
def trim_suffix(s: str, suffix: str):
|
||||||
|
return s if not s.endswith(suffix) or len(suffix) == 0 else s[: -len(suffix)]
|
||||||
|
|
||||||
|
|
||||||
class HfArgumentParser(ArgumentParser):
|
class HfArgumentParser(ArgumentParser):
|
||||||
"""
|
"""
|
||||||
This subclass of `argparse.ArgumentParser` uses type hints on dataclasses
|
This subclass of `argparse.ArgumentParser` uses type hints on dataclasses
|
||||||
@@ -70,7 +77,9 @@ class HfArgumentParser(ArgumentParser):
|
|||||||
kwargs["required"] = True
|
kwargs["required"] = True
|
||||||
self.add_argument(field_name, **kwargs)
|
self.add_argument(field_name, **kwargs)
|
||||||
|
|
||||||
def parse_args_into_dataclasses(self, args=None, return_remaining_strings=False) -> Tuple[DataClass, ...]:
|
def parse_args_into_dataclasses(
|
||||||
|
self, args=None, return_remaining_strings=False, look_for_args_file=True
|
||||||
|
) -> Tuple[DataClass, ...]:
|
||||||
"""
|
"""
|
||||||
Parse command-line args into instances of the specified dataclass types.
|
Parse command-line args into instances of the specified dataclass types.
|
||||||
|
|
||||||
@@ -84,6 +93,10 @@ class HfArgumentParser(ArgumentParser):
|
|||||||
(same as argparse.ArgumentParser)
|
(same as argparse.ArgumentParser)
|
||||||
return_remaining_strings:
|
return_remaining_strings:
|
||||||
If true, also return a list of remaining argument strings.
|
If true, also return a list of remaining argument strings.
|
||||||
|
look_for_args_file:
|
||||||
|
If true, will look for a ".args" file with the same base name
|
||||||
|
as the entry point script for this process, and will prepend its
|
||||||
|
content (if any) to the command line args, so that command-line values take precedence.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple consisting of:
|
Tuple consisting of:
|
||||||
@@ -95,6 +108,14 @@ class HfArgumentParser(ArgumentParser):
|
|||||||
- The potential list of remaining argument strings.
|
- The potential list of remaining argument strings.
|
||||||
(same as argparse.ArgumentParser.parse_known_args)
|
(same as argparse.ArgumentParser.parse_known_args)
|
||||||
"""
|
"""
|
||||||
|
if look_for_args_file and len(sys.argv):
|
||||||
|
basename = trim_suffix(sys.argv[0], ".py")
|
||||||
|
args_file = Path(f"{basename}.args")
|
||||||
|
if args_file.exists():
|
||||||
|
fargs = args_file.read_text().split()
|
||||||
|
args = fargs + args if args is not None else fargs + sys.argv[1:]
|
||||||
|
# in case of duplicate arguments argparse keeps the last one,
|
||||||
|
# so the .args file contents go first and the command-line args override them.
|
||||||
namespace, remaining_args = self.parse_known_args(args=args)
|
namespace, remaining_args = self.parse_known_args(args=args)
|
||||||
outputs = []
|
outputs = []
|
||||||
for dtype in self.dataclass_types:
|
for dtype in self.dataclass_types:
|
||||||
@@ -111,3 +132,17 @@ class HfArgumentParser(ArgumentParser):
|
|||||||
return (*outputs, remaining_args)
|
return (*outputs, remaining_args)
|
||||||
else:
|
else:
|
||||||
return (*outputs,)
|
return (*outputs,)
|
||||||
|
|
||||||
|
def parse_json_file(self, json_file: str) -> Tuple[DataClass, ...]:
|
||||||
|
"""
|
||||||
|
Alternative helper method that does not use `argparse` at all,
|
||||||
|
instead loading a json file and populating the dataclass types.
|
||||||
|
"""
|
||||||
|
data = json.loads(Path(json_file).read_text())
|
||||||
|
outputs = []
|
||||||
|
for dtype in self.dataclass_types:
|
||||||
|
keys = {f.name for f in dataclasses.fields(dtype)}
|
||||||
|
inputs = {k: v for k, v in data.items() if k in keys}
|
||||||
|
obj = dtype(**inputs)
|
||||||
|
outputs.append(obj)
|
||||||
|
return (*outputs,)
|
||||||
|
|||||||
Reference in New Issue
Block a user