[examples] Generate argparsers from type hints on dataclasses (#3669)
* [examples] Generate argparsers from type hints on dataclasses * [HfArgumentParser] way simpler API * Restore run_language_modeling.py for easier diff * [HfArgumentParser] final tweaks from code review
This commit is contained in:
@@ -22,6 +22,8 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@@ -36,6 +38,8 @@ from transformers import (
|
|||||||
AutoConfig,
|
AutoConfig,
|
||||||
AutoModelForSequenceClassification,
|
AutoModelForSequenceClassification,
|
||||||
AutoTokenizer,
|
AutoTokenizer,
|
||||||
|
HfArgumentParser,
|
||||||
|
TrainingArguments,
|
||||||
get_linear_schedule_with_warmup,
|
get_linear_schedule_with_warmup,
|
||||||
)
|
)
|
||||||
from transformers import glue_compute_metrics as compute_metrics
|
from transformers import glue_compute_metrics as compute_metrics
|
||||||
@@ -376,137 +380,54 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False):
|
|||||||
return dataset
|
return dataset
|
||||||
|
|
||||||
|
|
||||||
def main():
|
@dataclass
|
||||||
parser = argparse.ArgumentParser()
|
class ModelArguments:
|
||||||
|
"""
|
||||||
|
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
|
||||||
|
"""
|
||||||
|
|
||||||
# Required parameters
|
model_name_or_path: str = field(
|
||||||
parser.add_argument(
|
metadata={"help": "Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS)}
|
||||||
"--data_dir",
|
|
||||||
default=None,
|
|
||||||
type=str,
|
|
||||||
required=True,
|
|
||||||
help="The input data dir. Should contain the .tsv files (or other data files) for the task.",
|
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
model_type: str = field(metadata={"help": "Model type selected in the list: " + ", ".join(MODEL_TYPES)})
|
||||||
"--model_type",
|
config_name: Optional[str] = field(
|
||||||
default=None,
|
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
|
||||||
type=str,
|
|
||||||
required=True,
|
|
||||||
help="Model type selected in the list: " + ", ".join(MODEL_TYPES),
|
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
tokenizer_name: Optional[str] = field(
|
||||||
"--model_name_or_path",
|
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
|
||||||
default=None,
|
|
||||||
type=str,
|
|
||||||
required=True,
|
|
||||||
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS),
|
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
cache_dir: Optional[str] = field(
|
||||||
"--task_name",
|
default=None, metadata={"help": "Where do you want to store the pre-trained models downloaded from s3"}
|
||||||
default=None,
|
|
||||||
type=str,
|
|
||||||
required=True,
|
|
||||||
help="The name of the task to train selected in the list: " + ", ".join(processors.keys()),
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--output_dir",
|
|
||||||
default=None,
|
|
||||||
type=str,
|
|
||||||
required=True,
|
|
||||||
help="The output directory where the model predictions and checkpoints will be written.",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Other parameters
|
|
||||||
parser.add_argument(
|
@dataclass
|
||||||
"--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name",
|
class DataProcessingArguments:
|
||||||
|
task_name: str = field(
|
||||||
|
metadata={"help": "The name of the task to train selected in the list: " + ", ".join(processors.keys())}
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
data_dir: str = field(
|
||||||
"--tokenizer_name",
|
metadata={"help": "The input data dir. Should contain the .tsv files (or other data files) for the task."}
|
||||||
default="",
|
|
||||||
type=str,
|
|
||||||
help="Pretrained tokenizer name or path if not the same as model_name",
|
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
max_seq_length: int = field(
|
||||||
"--cache_dir",
|
|
||||||
default="",
|
|
||||||
type=str,
|
|
||||||
help="Where do you want to store the pre-trained models downloaded from s3",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--max_seq_length",
|
|
||||||
default=128,
|
default=128,
|
||||||
type=int,
|
metadata={
|
||||||
help="The maximum total input sequence length after tokenization. Sequences longer "
|
"help": "The maximum total input sequence length after tokenization. Sequences longer "
|
||||||
"than this will be truncated, sequences shorter will be padded.",
|
"than this will be truncated, sequences shorter will be padded."
|
||||||
|
},
|
||||||
)
|
)
|
||||||
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
|
overwrite_cache: bool = field(
|
||||||
parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
|
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
|
||||||
parser.add_argument(
|
|
||||||
"--evaluate_during_training", action="store_true", help="Run evaluation during training at each logging step.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model.",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--per_gpu_train_batch_size", default=8, type=int, help="Batch size per GPU/CPU for training.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--per_gpu_eval_batch_size", default=8, type=int, help="Batch size per GPU/CPU for evaluation.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--gradient_accumulation_steps",
|
|
||||||
type=int,
|
|
||||||
default=1,
|
|
||||||
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
|
||||||
)
|
|
||||||
parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
|
|
||||||
parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
|
|
||||||
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
|
|
||||||
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
|
|
||||||
parser.add_argument(
|
|
||||||
"--num_train_epochs", default=3.0, type=float, help="Total number of training epochs to perform.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--max_steps",
|
|
||||||
default=-1,
|
|
||||||
type=int,
|
|
||||||
help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
|
|
||||||
)
|
|
||||||
parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
|
|
||||||
|
|
||||||
parser.add_argument("--logging_steps", type=int, default=500, help="Log every X updates steps.")
|
def main():
|
||||||
parser.add_argument("--save_steps", type=int, default=500, help="Save checkpoint every X updates steps.")
|
parser = HfArgumentParser((ModelArguments, DataProcessingArguments, TrainingArguments))
|
||||||
parser.add_argument(
|
model_args, dataprocessing_args, training_args = parser.parse_args_into_dataclasses()
|
||||||
"--eval_all_checkpoints",
|
|
||||||
action="store_true",
|
|
||||||
help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number",
|
|
||||||
)
|
|
||||||
parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
|
|
||||||
parser.add_argument(
|
|
||||||
"--overwrite_output_dir", action="store_true", help="Overwrite the content of the output directory",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets",
|
|
||||||
)
|
|
||||||
parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
|
|
||||||
|
|
||||||
parser.add_argument(
|
# For now, let's merge all the sets of args into one,
|
||||||
"--fp16",
|
# but soon, we'll keep distinct sets of args, with a cleaner separation of concerns.
|
||||||
action="store_true",
|
args = argparse.Namespace(**vars(model_args), **vars(dataprocessing_args), **vars(training_args))
|
||||||
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--fp16_opt_level",
|
|
||||||
type=str,
|
|
||||||
default="O1",
|
|
||||||
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
|
||||||
"See details at https://nvidia.github.io/apex/amp.html",
|
|
||||||
)
|
|
||||||
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
|
||||||
parser.add_argument("--server_ip", type=str, default="", help="For distant debugging.")
|
|
||||||
parser.add_argument("--server_port", type=str, default="", help="For distant debugging.")
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
if (
|
if (
|
||||||
os.path.exists(args.output_dir)
|
os.path.exists(args.output_dir)
|
||||||
@@ -515,19 +436,8 @@ def main():
|
|||||||
and not args.overwrite_output_dir
|
and not args.overwrite_output_dir
|
||||||
):
|
):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
|
f"Output directory ({args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
|
||||||
args.output_dir
|
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
|
||||||
# Setup distant debugging if needed
|
|
||||||
if args.server_ip and args.server_port:
|
|
||||||
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
|
|
||||||
import ptvsd
|
|
||||||
|
|
||||||
print("Waiting for debugger attach")
|
|
||||||
ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
|
|
||||||
ptvsd.wait_for_attach()
|
|
||||||
|
|
||||||
# Setup CUDA, GPU & distributed training
|
# Setup CUDA, GPU & distributed training
|
||||||
if args.local_rank == -1 or args.no_cuda:
|
if args.local_rank == -1 or args.no_cuda:
|
||||||
@@ -576,18 +486,16 @@ def main():
|
|||||||
args.config_name if args.config_name else args.model_name_or_path,
|
args.config_name if args.config_name else args.model_name_or_path,
|
||||||
num_labels=num_labels,
|
num_labels=num_labels,
|
||||||
finetuning_task=args.task_name,
|
finetuning_task=args.task_name,
|
||||||
cache_dir=args.cache_dir if args.cache_dir else None,
|
cache_dir=args.cache_dir,
|
||||||
)
|
)
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
|
args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, cache_dir=args.cache_dir,
|
||||||
do_lower_case=args.do_lower_case,
|
|
||||||
cache_dir=args.cache_dir if args.cache_dir else None,
|
|
||||||
)
|
)
|
||||||
model = AutoModelForSequenceClassification.from_pretrained(
|
model = AutoModelForSequenceClassification.from_pretrained(
|
||||||
args.model_name_or_path,
|
args.model_name_or_path,
|
||||||
from_tf=bool(".ckpt" in args.model_name_or_path),
|
from_tf=bool(".ckpt" in args.model_name_or_path),
|
||||||
config=config,
|
config=config,
|
||||||
cache_dir=args.cache_dir if args.cache_dir else None,
|
cache_dir=args.cache_dir,
|
||||||
)
|
)
|
||||||
|
|
||||||
if args.local_rank == 0:
|
if args.local_rank == 0:
|
||||||
@@ -629,7 +537,7 @@ def main():
|
|||||||
# Evaluation
|
# Evaluation
|
||||||
results = {}
|
results = {}
|
||||||
if args.do_eval and args.local_rank in [-1, 0]:
|
if args.do_eval and args.local_rank in [-1, 0]:
|
||||||
tokenizer = AutoTokenizer.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
|
tokenizer = AutoTokenizer.from_pretrained(args.output_dir)
|
||||||
checkpoints = [args.output_dir]
|
checkpoints = [args.output_dir]
|
||||||
if args.eval_all_checkpoints:
|
if args.eval_all_checkpoints:
|
||||||
checkpoints = list(
|
checkpoints = list(
|
||||||
|
|||||||
@@ -88,6 +88,7 @@ from .file_utils import (
|
|||||||
is_tf_available,
|
is_tf_available,
|
||||||
is_torch_available,
|
is_torch_available,
|
||||||
)
|
)
|
||||||
|
from .hf_argparser import HfArgumentParser
|
||||||
|
|
||||||
# Model Cards
|
# Model Cards
|
||||||
from .modelcard import ModelCard
|
from .modelcard import ModelCard
|
||||||
@@ -141,6 +142,7 @@ from .tokenization_utils import PreTrainedTokenizer
|
|||||||
from .tokenization_xlm import XLMTokenizer
|
from .tokenization_xlm import XLMTokenizer
|
||||||
from .tokenization_xlm_roberta import XLMRobertaTokenizer
|
from .tokenization_xlm_roberta import XLMRobertaTokenizer
|
||||||
from .tokenization_xlnet import SPIECE_UNDERLINE, XLNetTokenizer
|
from .tokenization_xlnet import SPIECE_UNDERLINE, XLNetTokenizer
|
||||||
|
from .training_args import TrainingArguments
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
||||||
|
|||||||
113
src/transformers/hf_argparser.py
Normal file
113
src/transformers/hf_argparser.py
Normal file
@@ -0,0 +1,113 @@
|
|||||||
|
import dataclasses
|
||||||
|
from argparse import ArgumentParser
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any, Iterable, NewType, Tuple, Union
|
||||||
|
|
||||||
|
|
||||||
|
DataClass = NewType("DataClass", Any)
|
||||||
|
DataClassType = NewType("DataClassType", Any)
|
||||||
|
|
||||||
|
|
||||||
|
class HfArgumentParser(ArgumentParser):
|
||||||
|
"""
|
||||||
|
This subclass of `argparse.ArgumentParser` uses type hints on dataclasses
|
||||||
|
to generate arguments.
|
||||||
|
|
||||||
|
The class is designed to play well with the native argparse. In particular,
|
||||||
|
you can add more (non-dataclass backed) arguments to the parser after initialization
|
||||||
|
and you'll get the output back after parsing as an additional namespace.
|
||||||
|
"""
|
||||||
|
|
||||||
|
dataclass_types: Iterable[DataClassType]
|
||||||
|
|
||||||
|
def __init__(self, dataclass_types: Union[DataClassType, Iterable[DataClassType]], **kwargs):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
dataclass_types:
|
||||||
|
Dataclass type, or list of dataclass types for which we will "fill" instances
|
||||||
|
with the parsed args.
|
||||||
|
kwargs:
|
||||||
|
(Optional) Passed to `argparse.ArgumentParser()` in the regular way.
|
||||||
|
"""
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
if dataclasses.is_dataclass(dataclass_types):
|
||||||
|
dataclass_types = [dataclass_types]
|
||||||
|
self.dataclass_types = dataclass_types
|
||||||
|
for dtype in self.dataclass_types:
|
||||||
|
self._add_dataclass_arguments(dtype)
|
||||||
|
|
||||||
|
def _add_dataclass_arguments(self, dtype: DataClassType):
|
||||||
|
for field in dataclasses.fields(dtype):
|
||||||
|
field_name = f"--{field.name}"
|
||||||
|
kwargs = field.metadata.copy()
|
||||||
|
# field.metadata is not used at all by Data Classes,
|
||||||
|
# it is provided as a third-party extension mechanism.
|
||||||
|
if isinstance(field.type, str):
|
||||||
|
raise ImportError(
|
||||||
|
"This implementation is not compatible with Postponed Evaluation of Annotations (PEP 563),"
|
||||||
|
"which can be opted in from Python 3.7 with `from __future__ import annotations`."
|
||||||
|
"We will add compatibility when Python 3.9 is released."
|
||||||
|
)
|
||||||
|
typestring = str(field.type)
|
||||||
|
for x in (int, float, str):
|
||||||
|
if typestring == f"typing.Union[{x.__name__}, NoneType]":
|
||||||
|
field.type = x
|
||||||
|
if isinstance(field.type, type) and issubclass(field.type, Enum):
|
||||||
|
kwargs["choices"] = list(field.type)
|
||||||
|
kwargs["type"] = field.type
|
||||||
|
if field.default is not dataclasses.MISSING:
|
||||||
|
kwargs["default"] = field.default
|
||||||
|
elif field.type is bool:
|
||||||
|
kwargs["action"] = "store_false" if field.default is True else "store_true"
|
||||||
|
if field.default is True:
|
||||||
|
field_name = f"--no-{field.name}"
|
||||||
|
kwargs["dest"] = field.name
|
||||||
|
else:
|
||||||
|
kwargs["type"] = field.type
|
||||||
|
if field.default is not dataclasses.MISSING:
|
||||||
|
kwargs["default"] = field.default
|
||||||
|
else:
|
||||||
|
kwargs["required"] = True
|
||||||
|
self.add_argument(field_name, **kwargs)
|
||||||
|
|
||||||
|
def parse_args_into_dataclasses(self, args=None, return_remaining_strings=False) -> Tuple[DataClass, ...]:
|
||||||
|
"""
|
||||||
|
Parse command-line args into instances of the specified dataclass types.
|
||||||
|
|
||||||
|
This relies on argparse's `ArgumentParser.parse_known_args`.
|
||||||
|
See the doc at:
|
||||||
|
docs.python.org/3.7/library/argparse.html#argparse.ArgumentParser.parse_args
|
||||||
|
|
||||||
|
Args:
|
||||||
|
args:
|
||||||
|
List of strings to parse. The default is taken from sys.argv.
|
||||||
|
(same as argparse.ArgumentParser)
|
||||||
|
return_remaining_strings:
|
||||||
|
If true, also return a list of remaining argument strings.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple consisting of:
|
||||||
|
- the dataclass instances in the same order as they
|
||||||
|
were passed to the initializer.abspath
|
||||||
|
- if applicable, an additional namespace for more
|
||||||
|
(non-dataclass backed) arguments added to the parser
|
||||||
|
after initialization.
|
||||||
|
- The potential list of remaining argument strings.
|
||||||
|
(same as argparse.ArgumentParser.parse_known_args)
|
||||||
|
"""
|
||||||
|
namespace, remaining_args = self.parse_known_args(args=args)
|
||||||
|
outputs = []
|
||||||
|
for dtype in self.dataclass_types:
|
||||||
|
keys = {f.name for f in dataclasses.fields(dtype)}
|
||||||
|
inputs = {k: v for k, v in vars(namespace).items() if k in keys}
|
||||||
|
for k in keys:
|
||||||
|
delattr(namespace, k)
|
||||||
|
obj = dtype(**inputs)
|
||||||
|
outputs.append(obj)
|
||||||
|
if len(namespace.__dict__) > 0:
|
||||||
|
# additional namespace.
|
||||||
|
outputs.append(namespace)
|
||||||
|
if return_remaining_strings:
|
||||||
|
return (*outputs, remaining_args)
|
||||||
|
else:
|
||||||
|
return (*outputs,)
|
||||||
75
src/transformers/training_args.py
Normal file
75
src/transformers/training_args.py
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TrainingArguments:
|
||||||
|
"""
|
||||||
|
TrainingArguments is the subset of the arguments we use in our example scripts
|
||||||
|
**which relate to the training loop itself**.
|
||||||
|
|
||||||
|
Using `HfArgumentParser` we can turn this class
|
||||||
|
into argparse arguments to be able to specify them on
|
||||||
|
the command line.
|
||||||
|
"""
|
||||||
|
|
||||||
|
output_dir: str = field(
|
||||||
|
metadata={"help": "The output directory where the model predictions and checkpoints will be written."}
|
||||||
|
)
|
||||||
|
overwrite_output_dir: bool = field(
|
||||||
|
default=False, metadata={"help": "Overwrite the content of the output directory"}
|
||||||
|
)
|
||||||
|
|
||||||
|
do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
|
||||||
|
do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the dev set."})
|
||||||
|
evaluate_during_training: bool = field(
|
||||||
|
default=False, metadata={"help": "Run evaluation during training at each logging step."}
|
||||||
|
)
|
||||||
|
|
||||||
|
per_gpu_train_batch_size: int = field(default=8, metadata={"help": "Batch size per GPU/CPU for training."})
|
||||||
|
per_gpu_eval_batch_size: int = field(default=8, metadata={"help": "Batch size per GPU/CPU for evaluation."})
|
||||||
|
gradient_accumulation_steps: int = field(
|
||||||
|
default=1, metadata={"help": "Number of updates steps to accumulate before performing a backward/update pass."}
|
||||||
|
)
|
||||||
|
|
||||||
|
learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for Adam."})
|
||||||
|
weight_decay: float = field(default=0.0, metadata={"help": "Weight decay if we apply some."})
|
||||||
|
adam_epsilon: float = field(default=1e-8, metadata={"help": "Epsilon for Adam optimizer."})
|
||||||
|
max_grad_norm: float = field(default=1.0, metadata={"help": "Max gradient norm."})
|
||||||
|
|
||||||
|
num_train_epochs: float = field(default=3.0, metadata={"help": "Total number of training epochs to perform."})
|
||||||
|
max_steps: int = field(
|
||||||
|
default=-1,
|
||||||
|
metadata={"help": "If > 0: set total number of training steps to perform. Override num_train_epochs."},
|
||||||
|
)
|
||||||
|
warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."})
|
||||||
|
|
||||||
|
logging_steps: int = field(default=500, metadata={"help": "Log every X updates steps."})
|
||||||
|
save_steps: int = field(default=500, metadata={"help": "Save checkpoint every X updates steps."})
|
||||||
|
save_total_limit: Optional[int] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={
|
||||||
|
"help": "Limit the total amount of checkpoints, delete the older checkpoints in the output_dir, does not delete by default"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
eval_all_checkpoints: bool = field(
|
||||||
|
default=False,
|
||||||
|
metadata={
|
||||||
|
"help": "Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
no_cuda: bool = field(default=False, metadata={"help": "Avoid using CUDA even if it is available"})
|
||||||
|
seed: int = field(default=42, metadata={"help": "random seed for initialization"})
|
||||||
|
|
||||||
|
fp16: bool = field(
|
||||||
|
default=False,
|
||||||
|
metadata={"help": "Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit"},
|
||||||
|
)
|
||||||
|
fp16_opt_level: str = field(
|
||||||
|
default="O1",
|
||||||
|
metadata={
|
||||||
|
"help": "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
||||||
|
"See details at https://nvidia.github.io/apex/amp.html"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
local_rank: int = field(default=-1, metadata={"help": "For distributed training: local_rank"})
|
||||||
122
tests/test_hf_argparser.py
Normal file
122
tests/test_hf_argparser.py
Normal file
@@ -0,0 +1,122 @@
|
|||||||
|
import argparse
|
||||||
|
import unittest
|
||||||
|
from argparse import Namespace
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from transformers.hf_argparser import HfArgumentParser
|
||||||
|
from transformers.training_args import TrainingArguments
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class BasicExample:
|
||||||
|
foo: int
|
||||||
|
bar: float
|
||||||
|
baz: str
|
||||||
|
flag: bool
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class WithDefaultExample:
|
||||||
|
foo: int = 42
|
||||||
|
baz: str = field(default="toto", metadata={"help": "help message"})
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class WithDefaultBoolExample:
|
||||||
|
foo: bool = False
|
||||||
|
baz: bool = True
|
||||||
|
|
||||||
|
|
||||||
|
class BasicEnum(Enum):
|
||||||
|
titi = "titi"
|
||||||
|
toto = "toto"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class EnumExample:
|
||||||
|
foo: BasicEnum = BasicEnum.toto
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class OptionalExample:
|
||||||
|
foo: Optional[int] = None
|
||||||
|
bar: Optional[float] = field(default=None, metadata={"help": "help message"})
|
||||||
|
baz: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class HfArgumentParserTest(unittest.TestCase):
|
||||||
|
def argparsersEqual(self, a: argparse.ArgumentParser, b: argparse.ArgumentParser) -> bool:
|
||||||
|
"""
|
||||||
|
Small helper to check pseudo-equality of parsed arguments on `ArgumentParser` instances.
|
||||||
|
"""
|
||||||
|
self.assertEqual(len(a._actions), len(b._actions))
|
||||||
|
for x, y in zip(a._actions, b._actions):
|
||||||
|
xx = {k: v for k, v in vars(x).items() if k != "container"}
|
||||||
|
yy = {k: v for k, v in vars(y).items() if k != "container"}
|
||||||
|
self.assertEqual(xx, yy)
|
||||||
|
|
||||||
|
def test_basic(self):
|
||||||
|
parser = HfArgumentParser(BasicExample)
|
||||||
|
|
||||||
|
expected = argparse.ArgumentParser()
|
||||||
|
expected.add_argument("--foo", type=int, required=True)
|
||||||
|
expected.add_argument("--bar", type=float, required=True)
|
||||||
|
expected.add_argument("--baz", type=str, required=True)
|
||||||
|
expected.add_argument("--flag", action="store_true")
|
||||||
|
self.argparsersEqual(parser, expected)
|
||||||
|
|
||||||
|
def test_with_default(self):
|
||||||
|
parser = HfArgumentParser(WithDefaultExample)
|
||||||
|
|
||||||
|
expected = argparse.ArgumentParser()
|
||||||
|
expected.add_argument("--foo", default=42, type=int)
|
||||||
|
expected.add_argument("--baz", default="toto", type=str, help="help message")
|
||||||
|
self.argparsersEqual(parser, expected)
|
||||||
|
|
||||||
|
def test_with_default_bool(self):
|
||||||
|
parser = HfArgumentParser(WithDefaultBoolExample)
|
||||||
|
|
||||||
|
expected = argparse.ArgumentParser()
|
||||||
|
expected.add_argument("--foo", action="store_true")
|
||||||
|
expected.add_argument("--no-baz", action="store_false", dest="baz")
|
||||||
|
self.argparsersEqual(parser, expected)
|
||||||
|
|
||||||
|
args = parser.parse_args([])
|
||||||
|
self.assertEqual(args, Namespace(foo=False, baz=True))
|
||||||
|
|
||||||
|
args = parser.parse_args(["--foo", "--no-baz"])
|
||||||
|
self.assertEqual(args, Namespace(foo=True, baz=False))
|
||||||
|
|
||||||
|
def test_with_enum(self):
|
||||||
|
parser = HfArgumentParser(EnumExample)
|
||||||
|
|
||||||
|
expected = argparse.ArgumentParser()
|
||||||
|
expected.add_argument("--foo", default=BasicEnum.toto, choices=list(BasicEnum), type=BasicEnum)
|
||||||
|
self.argparsersEqual(parser, expected)
|
||||||
|
|
||||||
|
args = parser.parse_args([])
|
||||||
|
self.assertEqual(args.foo, BasicEnum.toto)
|
||||||
|
|
||||||
|
args = parser.parse_args(["--foo", "titi"])
|
||||||
|
self.assertEqual(args.foo, BasicEnum.titi)
|
||||||
|
|
||||||
|
def test_with_optional(self):
|
||||||
|
parser = HfArgumentParser(OptionalExample)
|
||||||
|
|
||||||
|
expected = argparse.ArgumentParser()
|
||||||
|
expected.add_argument("--foo", default=None, type=int)
|
||||||
|
expected.add_argument("--bar", default=None, type=float, help="help message")
|
||||||
|
expected.add_argument("--baz", default=None, type=str)
|
||||||
|
self.argparsersEqual(parser, expected)
|
||||||
|
|
||||||
|
args = parser.parse_args([])
|
||||||
|
self.assertEqual(args, Namespace(foo=None, bar=None, baz=None))
|
||||||
|
|
||||||
|
args = parser.parse_args("--foo 12 --bar 3.14 --baz 42".split())
|
||||||
|
self.assertEqual(args, Namespace(foo=12, bar=3.14, baz="42"))
|
||||||
|
|
||||||
|
def test_integration_training_args(self):
|
||||||
|
parser = HfArgumentParser(TrainingArguments)
|
||||||
|
self.assertIsNotNone(parser)
|
||||||
Reference in New Issue
Block a user