[examples] Generate argparsers from type hints on dataclasses (#3669)
* [examples] Generate argparsers from type hints on dataclasses * [HfArgumentParser] way simpler API * Restore run_language_modeling.py for easier diff * [HfArgumentParser] final tweaks from code review
This commit is contained in:
@@ -88,6 +88,7 @@ from .file_utils import (
|
||||
is_tf_available,
|
||||
is_torch_available,
|
||||
)
|
||||
from .hf_argparser import HfArgumentParser
|
||||
|
||||
# Model Cards
|
||||
from .modelcard import ModelCard
|
||||
@@ -141,6 +142,7 @@ from .tokenization_utils import PreTrainedTokenizer
|
||||
from .tokenization_xlm import XLMTokenizer
|
||||
from .tokenization_xlm_roberta import XLMRobertaTokenizer
|
||||
from .tokenization_xlnet import SPIECE_UNDERLINE, XLNetTokenizer
|
||||
from .training_args import TrainingArguments
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
113
src/transformers/hf_argparser.py
Normal file
113
src/transformers/hf_argparser.py
Normal file
@@ -0,0 +1,113 @@
|
||||
import dataclasses
|
||||
from argparse import ArgumentParser
|
||||
from enum import Enum
|
||||
from typing import Any, Iterable, NewType, Tuple, Union
|
||||
|
||||
|
||||
# Annotation-only aliases: `DataClass` stands for a dataclass *instance*,
# `DataClassType` for a dataclass *class*. Both wrap `Any`.
DataClass = NewType("DataClass", Any)
DataClassType = NewType("DataClassType", Any)


class HfArgumentParser(ArgumentParser):
    """
    This subclass of `argparse.ArgumentParser` uses type hints on dataclasses
    to generate arguments.

    The class is designed to play well with the native argparse. In particular,
    you can add more (non-dataclass backed) arguments to the parser after initialization
    and you'll get the output back after parsing as an additional namespace.

    Each dataclass field becomes one `--<field name>` option:

    - `Enum` subclasses are exposed with `choices` set to the enum members.
    - `bool` fields become flags (`--name`, or `--no-name` when the default is True).
    - everything else uses the field's type as the argparse `type` converter;
      `Optional[int]`, `Optional[float]` and `Optional[str]` are unwrapped first.

    The entries of each field's `metadata` mapping are forwarded to
    `add_argument` as keyword arguments (e.g. `{"help": "..."}`).
    """

    dataclass_types: Iterable[DataClassType]

    def __init__(self, dataclass_types: Union[DataClassType, Iterable[DataClassType]], **kwargs):
        """
        Args:
            dataclass_types:
                Dataclass type, or list of dataclass types for which we will "fill" instances
                with the parsed args.
            kwargs:
                (Optional) Passed to `argparse.ArgumentParser()` in the regular way.
        """
        super().__init__(**kwargs)
        if dataclasses.is_dataclass(dataclass_types):
            dataclass_types = [dataclass_types]
        # Materialize the iterable: it is walked here and again on every parse,
        # so a one-shot iterator (e.g. a generator) would otherwise be exhausted.
        self.dataclass_types = list(dataclass_types)
        for dtype in self.dataclass_types:
            self._add_dataclass_arguments(dtype)

    @staticmethod
    def _unwrap_optional(field_type):
        """Return `X` for `Optional[X]` when `X` is int, float or str; otherwise return `field_type` unchanged."""
        import types  # local import: only needed for the PEP 604 check below

        pep604_union = getattr(types, "UnionType", None)  # absent before Python 3.10
        is_union = getattr(field_type, "__origin__", None) is Union or (
            pep604_union is not None and isinstance(field_type, pep604_union)
        )
        if not is_union:
            return field_type
        members = getattr(field_type, "__args__", ())
        non_none = [member for member in members if member is not type(None)]
        if len(members) == 2 and len(non_none) == 1 and non_none[0] in (int, float, str):
            return non_none[0]
        return field_type

    def _add_dataclass_arguments(self, dtype: DataClassType):
        """
        Register one argparse option per init-field of the dataclass `dtype`.

        Raises:
            ImportError: if a field annotation is a plain string, i.e. the
                dataclass was defined under postponed evaluation of annotations.
        """
        for field in dataclasses.fields(dtype):
            if not field.init:
                # Not a constructor parameter, so it cannot be filled from the command line.
                continue
            field_name = f"--{field.name}"
            # field.metadata is not used at all by Data Classes,
            # it is provided as a third-party extension mechanism.
            kwargs = dict(field.metadata)
            if isinstance(field.type, str):
                raise ImportError(
                    "This implementation is not compatible with Postponed Evaluation of Annotations (PEP 563), "
                    "which can be opted in from Python 3.7 with `from __future__ import annotations`. "
                    "We will add compatibility when Python 3.9 is released."
                )
            # Work on a local copy of the type: the Field object belongs to the
            # dataclass and must not be mutated.
            field_type = self._unwrap_optional(field.type)
            if isinstance(field_type, type) and issubclass(field_type, Enum):
                kwargs["choices"] = list(field_type)
                kwargs["type"] = field_type
                if field.default is not dataclasses.MISSING:
                    kwargs["default"] = field.default
            elif field_type is bool:
                kwargs["action"] = "store_false" if field.default is True else "store_true"
                if field.default is True:
                    # A flag can only switch a True default off, so expose it as `--no-<name>`
                    # while still storing into the attribute named after the field.
                    field_name = f"--no-{field.name}"
                    kwargs["dest"] = field.name
            else:
                kwargs["type"] = field_type
                if field.default is not dataclasses.MISSING:
                    kwargs["default"] = field.default
                elif field.default_factory is not dataclasses.MISSING:
                    kwargs["default"] = field.default_factory()
                else:
                    kwargs["required"] = True
            self.add_argument(field_name, **kwargs)

    def parse_args_into_dataclasses(self, args=None, return_remaining_strings=False) -> Tuple[DataClass, ...]:
        """
        Parse command-line args into instances of the specified dataclass types.

        This relies on argparse's `ArgumentParser.parse_known_args`.
        See the doc at:
        docs.python.org/3.7/library/argparse.html#argparse.ArgumentParser.parse_args

        Args:
            args:
                List of strings to parse. The default is taken from sys.argv.
                (same as argparse.ArgumentParser)
            return_remaining_strings:
                If true, also return a list of remaining argument strings.

        Returns:
            Tuple consisting of:
                - the dataclass instances in the same order as they
                  were passed to the initializer.
                - if applicable, an additional namespace for more
                  (non-dataclass backed) arguments added to the parser
                  after initialization.
                - The potential list of remaining argument strings.
                  (same as argparse.ArgumentParser.parse_known_args)
        """
        namespace, remaining_args = self.parse_known_args(args=args)
        outputs = []
        for dtype in self.dataclass_types:
            keys = {f.name for f in dataclasses.fields(dtype) if f.init}
            inputs = {k: v for k, v in vars(namespace).items() if k in keys}
            # Remove the consumed attributes so that whatever is left on the
            # namespace is exactly the set of non-dataclass arguments.
            for k in keys:
                delattr(namespace, k)
            outputs.append(dtype(**inputs))
        if vars(namespace):
            # additional namespace.
            outputs.append(namespace)
        if return_remaining_strings:
            return (*outputs, remaining_args)
        return tuple(outputs)
|
||||
75
src/transformers/training_args.py
Normal file
75
src/transformers/training_args.py
Normal file
@@ -0,0 +1,75 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
|
||||
@dataclass
class TrainingArguments:
    """
    TrainingArguments is the subset of the arguments we use in our example scripts
    **which relate to the training loop itself**.

    Using `HfArgumentParser` we can turn this class
    into argparse arguments to be able to specify them on
    the command line.

    Only `output_dir` has no default and must always be supplied; every other
    field carries a default. The `help` entry in each field's metadata is the
    user-facing description of that option.
    """

    # Output location.
    output_dir: str = field(
        metadata={"help": "The output directory where the model predictions and checkpoints will be written."}
    )
    overwrite_output_dir: bool = field(
        default=False, metadata={"help": "Overwrite the content of the output directory"}
    )

    # Which phases to run.
    do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
    do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the dev set."})
    evaluate_during_training: bool = field(
        default=False, metadata={"help": "Run evaluation during training at each logging step."}
    )

    # Batching.
    per_gpu_train_batch_size: int = field(default=8, metadata={"help": "Batch size per GPU/CPU for training."})
    per_gpu_eval_batch_size: int = field(default=8, metadata={"help": "Batch size per GPU/CPU for evaluation."})
    gradient_accumulation_steps: int = field(
        default=1, metadata={"help": "Number of updates steps to accumulate before performing a backward/update pass."}
    )

    # Optimizer hyper-parameters.
    learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for Adam."})
    weight_decay: float = field(default=0.0, metadata={"help": "Weight decay if we apply some."})
    adam_epsilon: float = field(default=1e-8, metadata={"help": "Epsilon for Adam optimizer."})
    max_grad_norm: float = field(default=1.0, metadata={"help": "Max gradient norm."})

    # Training length and schedule.
    num_train_epochs: float = field(default=3.0, metadata={"help": "Total number of training epochs to perform."})
    max_steps: int = field(
        default=-1,
        metadata={"help": "If > 0: set total number of training steps to perform. Override num_train_epochs."},
    )
    warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."})

    # Logging and checkpointing.
    logging_steps: int = field(default=500, metadata={"help": "Log every X updates steps."})
    save_steps: int = field(default=500, metadata={"help": "Save checkpoint every X updates steps."})
    save_total_limit: Optional[int] = field(
        default=None,
        metadata={
            "help": "Limit the total amount of checkpoints, delete the older checkpoints in the output_dir, does not delete by default"
        },
    )
    eval_all_checkpoints: bool = field(
        default=False,
        metadata={
            "help": "Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number"
        },
    )

    # Hardware and reproducibility.
    no_cuda: bool = field(default=False, metadata={"help": "Avoid using CUDA even if it is available"})
    seed: int = field(default=42, metadata={"help": "random seed for initialization"})

    # Mixed precision.
    fp16: bool = field(
        default=False,
        metadata={"help": "Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit"},
    )
    fp16_opt_level: str = field(
        default="O1",
        metadata={
            # The trailing space keeps the two adjacent literals from running together.
            "help": "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']. "
            "See details at https://nvidia.github.io/apex/amp.html"
        },
    )

    # Distributed training.
    local_rank: int = field(default=-1, metadata={"help": "For distributed training: local_rank"})
|
||||
Reference in New Issue
Block a user