From b169ac9c2ba62c828000516dbce1af9126ca25ab Mon Sep 17 00:00:00 2001 From: Julien Chaumond Date: Fri, 10 Apr 2020 12:21:58 -0400 Subject: [PATCH] [examples] Generate argparsers from type hints on dataclasses (#3669) * [examples] Generate argparsers from type hints on dataclasses * [HfArgumentParser] way simpler API * Restore run_language_modeling.py for easier diff * [HfArgumentParser] final tweaks from code review --- examples/run_glue.py | 178 ++++++++---------------------- src/transformers/__init__.py | 2 + src/transformers/hf_argparser.py | 113 +++++++++++++++++++ src/transformers/training_args.py | 75 +++++++++++++ tests/test_hf_argparser.py | 122 ++++++++++++++++++++ 5 files changed, 355 insertions(+), 135 deletions(-) create mode 100644 src/transformers/hf_argparser.py create mode 100644 src/transformers/training_args.py create mode 100644 tests/test_hf_argparser.py diff --git a/examples/run_glue.py b/examples/run_glue.py index 130bb19a82..1f586b7c56 100644 --- a/examples/run_glue.py +++ b/examples/run_glue.py @@ -22,6 +22,8 @@ import json import logging import os import random +from dataclasses import dataclass, field +from typing import Optional import numpy as np import torch @@ -36,6 +38,8 @@ from transformers import ( AutoConfig, AutoModelForSequenceClassification, AutoTokenizer, + HfArgumentParser, + TrainingArguments, get_linear_schedule_with_warmup, ) from transformers import glue_compute_metrics as compute_metrics @@ -376,137 +380,54 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False): return dataset -def main(): - parser = argparse.ArgumentParser() +@dataclass +class ModelArguments: + """ + Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. + """ - # Required parameters - parser.add_argument( - "--data_dir", - default=None, - type=str, - required=True, - help="The input data dir. Should contain the .tsv files (or other data files) for the task.", + model_name_or_path: str = field( + metadata={"help": "Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS)} ) - parser.add_argument( - "--model_type", - default=None, - type=str, - required=True, - help="Model type selected in the list: " + ", ".join(MODEL_TYPES), + model_type: str = field(metadata={"help": "Model type selected in the list: " + ", ".join(MODEL_TYPES)}) + config_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} ) - parser.add_argument( - "--model_name_or_path", - default=None, - type=str, - required=True, - help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS), + tokenizer_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} ) - parser.add_argument( - "--task_name", - default=None, - type=str, - required=True, - help="The name of the task to train selected in the list: " + ", ".join(processors.keys()), - ) - parser.add_argument( - "--output_dir", - default=None, - type=str, - required=True, - help="The output directory where the model predictions and checkpoints will be written.", + cache_dir: Optional[str] = field( + default=None, metadata={"help": "Where do you want to store the pre-trained models downloaded from s3"} ) - # Other parameters - parser.add_argument( - "--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name", + +@dataclass +class DataProcessingArguments: + task_name: str = field( + metadata={"help": "The name of the task to train selected in the list: " + ", ".join(processors.keys())} ) - parser.add_argument( - "--tokenizer_name", - default="", - type=str, - help="Pretrained tokenizer name or path if not the same as model_name", + data_dir: str = field( + metadata={"help": "The input data dir. Should contain the .tsv files (or other data files) for the task."} ) - parser.add_argument( - "--cache_dir", - default="", - type=str, - help="Where do you want to store the pre-trained models downloaded from s3", - ) - parser.add_argument( - "--max_seq_length", + max_seq_length: int = field( default=128, - type=int, - help="The maximum total input sequence length after tokenization. Sequences longer " - "than this will be truncated, sequences shorter will be padded.", + metadata={ + "help": "The maximum total input sequence length after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded." + }, ) - parser.add_argument("--do_train", action="store_true", help="Whether to run training.") - parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.") - parser.add_argument( - "--evaluate_during_training", action="store_true", help="Run evaluation during training at each logging step.", - ) - parser.add_argument( - "--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model.", + overwrite_cache: bool = field( + default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} ) - parser.add_argument( - "--per_gpu_train_batch_size", default=8, type=int, help="Batch size per GPU/CPU for training.", - ) - parser.add_argument( - "--per_gpu_eval_batch_size", default=8, type=int, help="Batch size per GPU/CPU for evaluation.", - ) - parser.add_argument( - "--gradient_accumulation_steps", - type=int, - default=1, - help="Number of updates steps to accumulate before performing a backward/update pass.", - ) - parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.") - parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.") - parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.") - parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") - parser.add_argument( - "--num_train_epochs", default=3.0, type=float, help="Total number of training epochs to perform.", - ) - parser.add_argument( - "--max_steps", - default=-1, - type=int, - help="If > 0: set total number of training steps to perform. Override num_train_epochs.", - ) - parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.") - parser.add_argument("--logging_steps", type=int, default=500, help="Log every X updates steps.") - parser.add_argument("--save_steps", type=int, default=500, help="Save checkpoint every X updates steps.") - parser.add_argument( - "--eval_all_checkpoints", - action="store_true", - help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number", - ) - parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available") - parser.add_argument( - "--overwrite_output_dir", action="store_true", help="Overwrite the content of the output directory", - ) - parser.add_argument( - "--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets", - ) - parser.add_argument("--seed", type=int, default=42, help="random seed for initialization") +def main(): + parser = HfArgumentParser((ModelArguments, DataProcessingArguments, TrainingArguments)) + model_args, dataprocessing_args, training_args = parser.parse_args_into_dataclasses() - parser.add_argument( - "--fp16", - action="store_true", - help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit", - ) - parser.add_argument( - "--fp16_opt_level", - type=str, - default="O1", - help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']." - "See details at https://nvidia.github.io/apex/amp.html", - ) - parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") - parser.add_argument("--server_ip", type=str, default="", help="For distant debugging.") - parser.add_argument("--server_port", type=str, default="", help="For distant debugging.") - args = parser.parse_args() + # For now, let's merge all the sets of args into one, + # but soon, we'll keep distinct sets of args, with a cleaner separation of concerns. + args = argparse.Namespace(**vars(model_args), **vars(dataprocessing_args), **vars(training_args)) if ( os.path.exists(args.output_dir) @@ -515,20 +436,9 @@ def main(): and not args.overwrite_output_dir ): raise ValueError( - "Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format( - args.output_dir - ) + f"Output directory ({args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome." ) - # Setup distant debugging if needed - if args.server_ip and args.server_port: - # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script - import ptvsd - - print("Waiting for debugger attach") - ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True) - ptvsd.wait_for_attach() - # Setup CUDA, GPU & distributed training if args.local_rank == -1 or args.no_cuda: device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") @@ -576,18 +486,16 @@ def main(): args.config_name if args.config_name else args.model_name_or_path, num_labels=num_labels, finetuning_task=args.task_name, - cache_dir=args.cache_dir if args.cache_dir else None, + cache_dir=args.cache_dir, ) tokenizer = AutoTokenizer.from_pretrained( - args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, - do_lower_case=args.do_lower_case, - cache_dir=args.cache_dir if args.cache_dir else None, + args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, cache_dir=args.cache_dir, ) model = AutoModelForSequenceClassification.from_pretrained( args.model_name_or_path, from_tf=bool(".ckpt" in args.model_name_or_path), config=config, - cache_dir=args.cache_dir if args.cache_dir else None, + cache_dir=args.cache_dir, ) if args.local_rank == 0: @@ -629,7 +537,7 @@ def main(): # Evaluation results = {} if args.do_eval and args.local_rank in [-1, 0]: - tokenizer = AutoTokenizer.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case) + tokenizer = AutoTokenizer.from_pretrained(args.output_dir) checkpoints = [args.output_dir] if args.eval_all_checkpoints: checkpoints = list( diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 206020ddf0..01f5600164 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -88,6 +88,7 @@ from .file_utils import ( is_tf_available, is_torch_available, ) +from .hf_argparser import HfArgumentParser # Model Cards from .modelcard import ModelCard @@ -141,6 +142,7 @@ from .tokenization_utils import PreTrainedTokenizer from .tokenization_xlm import XLMTokenizer from .tokenization_xlm_roberta import XLMRobertaTokenizer from .tokenization_xlnet import SPIECE_UNDERLINE, XLNetTokenizer +from .training_args import TrainingArguments logger = logging.getLogger(__name__) # pylint: disable=invalid-name diff --git a/src/transformers/hf_argparser.py b/src/transformers/hf_argparser.py new file mode 100644 index 0000000000..35dc83d7ca --- /dev/null +++ b/src/transformers/hf_argparser.py @@ -0,0 +1,113 @@ +import dataclasses +from argparse import ArgumentParser +from enum import Enum +from typing import Any, Iterable, NewType, Tuple, Union + + +DataClass = NewType("DataClass", Any) +DataClassType = NewType("DataClassType", Any) + + +class HfArgumentParser(ArgumentParser): + """ + This subclass of `argparse.ArgumentParser` uses type hints on dataclasses + to generate arguments. + + The class is designed to play well with the native argparse. In particular, + you can add more (non-dataclass backed) arguments to the parser after initialization + and you'll get the output back after parsing as an additional namespace. + """ + + dataclass_types: Iterable[DataClassType] + + def __init__(self, dataclass_types: Union[DataClassType, Iterable[DataClassType]], **kwargs): + """ + Args: + dataclass_types: + Dataclass type, or list of dataclass types for which we will "fill" instances + with the parsed args. + kwargs: + (Optional) Passed to `argparse.ArgumentParser()` in the regular way. + """ + super().__init__(**kwargs) + if dataclasses.is_dataclass(dataclass_types): + dataclass_types = [dataclass_types] + self.dataclass_types = dataclass_types + for dtype in self.dataclass_types: + self._add_dataclass_arguments(dtype) + + def _add_dataclass_arguments(self, dtype: DataClassType): + for field in dataclasses.fields(dtype): + field_name = f"--{field.name}" + kwargs = field.metadata.copy() + # field.metadata is not used at all by Data Classes, + # it is provided as a third-party extension mechanism. + if isinstance(field.type, str): + raise ImportError( + "This implementation is not compatible with Postponed Evaluation of Annotations (PEP 563)," + "which can be opted in from Python 3.7 with `from __future__ import annotations`." + "We will add compatibility when Python 3.9 is released." + ) + typestring = str(field.type) + for x in (int, float, str): + if typestring == f"typing.Union[{x.__name__}, NoneType]": + field.type = x + if isinstance(field.type, type) and issubclass(field.type, Enum): + kwargs["choices"] = list(field.type) + kwargs["type"] = field.type + if field.default is not dataclasses.MISSING: + kwargs["default"] = field.default + elif field.type is bool: + kwargs["action"] = "store_false" if field.default is True else "store_true" + if field.default is True: + field_name = f"--no-{field.name}" + kwargs["dest"] = field.name + else: + kwargs["type"] = field.type + if field.default is not dataclasses.MISSING: + kwargs["default"] = field.default + else: + kwargs["required"] = True + self.add_argument(field_name, **kwargs) + + def parse_args_into_dataclasses(self, args=None, return_remaining_strings=False) -> Tuple[DataClass, ...]: + """ + Parse command-line args into instances of the specified dataclass types. + + This relies on argparse's `ArgumentParser.parse_known_args`. + See the doc at: + docs.python.org/3.7/library/argparse.html#argparse.ArgumentParser.parse_args + + Args: + args: + List of strings to parse. The default is taken from sys.argv. + (same as argparse.ArgumentParser) + return_remaining_strings: + If true, also return a list of remaining argument strings. + + Returns: + Tuple consisting of: + - the dataclass instances in the same order as they + were passed to the initializer.abspath + - if applicable, an additional namespace for more + (non-dataclass backed) arguments added to the parser + after initialization. + - The potential list of remaining argument strings. + (same as argparse.ArgumentParser.parse_known_args) + """ + namespace, remaining_args = self.parse_known_args(args=args) + outputs = [] + for dtype in self.dataclass_types: + keys = {f.name for f in dataclasses.fields(dtype)} + inputs = {k: v for k, v in vars(namespace).items() if k in keys} + for k in keys: + delattr(namespace, k) + obj = dtype(**inputs) + outputs.append(obj) + if len(namespace.__dict__) > 0: + # additional namespace. + outputs.append(namespace) + if return_remaining_strings: + return (*outputs, remaining_args) + else: + return (*outputs,) diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py new file mode 100644 index 0000000000..b48486dfb0 --- /dev/null +++ b/src/transformers/training_args.py @@ -0,0 +1,75 @@ +from dataclasses import dataclass, field +from typing import Optional + + +@dataclass +class TrainingArguments: + """ + TrainingArguments is the subset of the arguments we use in our example scripts + **which relate to the training loop itself**. + + Using `HfArgumentParser` we can turn this class + into argparse arguments to be able to specify them on + the command line. + """ + + output_dir: str = field( + metadata={"help": "The output directory where the model predictions and checkpoints will be written."} + ) + overwrite_output_dir: bool = field( + default=False, metadata={"help": "Overwrite the content of the output directory"} + ) + + do_train: bool = field(default=False, metadata={"help": "Whether to run training."}) + do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the dev set."}) + evaluate_during_training: bool = field( + default=False, metadata={"help": "Run evaluation during training at each logging step."} + ) + + per_gpu_train_batch_size: int = field(default=8, metadata={"help": "Batch size per GPU/CPU for training."}) + per_gpu_eval_batch_size: int = field(default=8, metadata={"help": "Batch size per GPU/CPU for evaluation."}) + gradient_accumulation_steps: int = field( + default=1, metadata={"help": "Number of updates steps to accumulate before performing a backward/update pass."} + ) + + learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for Adam."}) + weight_decay: float = field(default=0.0, metadata={"help": "Weight decay if we apply some."}) + adam_epsilon: float = field(default=1e-8, metadata={"help": "Epsilon for Adam optimizer."}) + max_grad_norm: float = field(default=1.0, metadata={"help": "Max gradient norm."}) + + num_train_epochs: float = field(default=3.0, metadata={"help": "Total number of training epochs to perform."}) + max_steps: int = field( + default=-1, + metadata={"help": "If > 0: set total number of training steps to perform. Override num_train_epochs."}, + ) + warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."}) + + logging_steps: int = field(default=500, metadata={"help": "Log every X updates steps."}) + save_steps: int = field(default=500, metadata={"help": "Save checkpoint every X updates steps."}) + save_total_limit: Optional[int] = field( + default=None, + metadata={ + "help": "Limit the total amount of checkpoints, delete the older checkpoints in the output_dir, does not delete by default" + }, + ) + eval_all_checkpoints: bool = field( + default=False, + metadata={ + "help": "Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number" + }, + ) + no_cuda: bool = field(default=False, metadata={"help": "Avoid using CUDA even if it is available"}) + seed: int = field(default=42, metadata={"help": "random seed for initialization"}) + + fp16: bool = field( + default=False, + metadata={"help": "Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit"}, + ) + fp16_opt_level: str = field( + default="O1", + metadata={ + "help": "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']." + "See details at https://nvidia.github.io/apex/amp.html" + }, + ) + local_rank: int = field(default=-1, metadata={"help": "For distributed training: local_rank"}) diff --git a/tests/test_hf_argparser.py b/tests/test_hf_argparser.py new file mode 100644 index 0000000000..232d2e86af --- /dev/null +++ b/tests/test_hf_argparser.py @@ -0,0 +1,122 @@ +import argparse +import unittest +from argparse import Namespace +from dataclasses import dataclass, field +from enum import Enum +from typing import Optional + +from transformers.hf_argparser import HfArgumentParser +from transformers.training_args import TrainingArguments + + +@dataclass +class BasicExample: + foo: int + bar: float + baz: str + flag: bool + + +@dataclass +class WithDefaultExample: + foo: int = 42 + baz: str = field(default="toto", metadata={"help": "help message"}) + + +@dataclass +class WithDefaultBoolExample: + foo: bool = False + baz: bool = True + + +class BasicEnum(Enum): + titi = "titi" + toto = "toto" + + +@dataclass +class EnumExample: + foo: BasicEnum = BasicEnum.toto + + +@dataclass +class OptionalExample: + foo: Optional[int] = None + bar: Optional[float] = field(default=None, metadata={"help": "help message"}) + baz: Optional[str] = None + + +class HfArgumentParserTest(unittest.TestCase): + def argparsersEqual(self, a: argparse.ArgumentParser, b: argparse.ArgumentParser) -> bool: + """ + Small helper to check pseudo-equality of parsed arguments on `ArgumentParser` instances. + """ + self.assertEqual(len(a._actions), len(b._actions)) + for x, y in zip(a._actions, b._actions): + xx = {k: v for k, v in vars(x).items() if k != "container"} + yy = {k: v for k, v in vars(y).items() if k != "container"} + self.assertEqual(xx, yy) + + def test_basic(self): + parser = HfArgumentParser(BasicExample) + + expected = argparse.ArgumentParser() + expected.add_argument("--foo", type=int, required=True) + expected.add_argument("--bar", type=float, required=True) + expected.add_argument("--baz", type=str, required=True) + expected.add_argument("--flag", action="store_true") + self.argparsersEqual(parser, expected) + + def test_with_default(self): + parser = HfArgumentParser(WithDefaultExample) + + expected = argparse.ArgumentParser() + expected.add_argument("--foo", default=42, type=int) + expected.add_argument("--baz", default="toto", type=str, help="help message") + self.argparsersEqual(parser, expected) + + def test_with_default_bool(self): + parser = HfArgumentParser(WithDefaultBoolExample) + + expected = argparse.ArgumentParser() + expected.add_argument("--foo", action="store_true") + expected.add_argument("--no-baz", action="store_false", dest="baz") + self.argparsersEqual(parser, expected) + + args = parser.parse_args([]) + self.assertEqual(args, Namespace(foo=False, baz=True)) + + args = parser.parse_args(["--foo", "--no-baz"]) + self.assertEqual(args, Namespace(foo=True, baz=False)) + + def test_with_enum(self): + parser = HfArgumentParser(EnumExample) + + expected = argparse.ArgumentParser() + expected.add_argument("--foo", default=BasicEnum.toto, choices=list(BasicEnum), type=BasicEnum) + self.argparsersEqual(parser, expected) + + args = parser.parse_args([]) + self.assertEqual(args.foo, BasicEnum.toto) + + args = parser.parse_args(["--foo", "titi"]) + self.assertEqual(args.foo, BasicEnum.titi) + + def test_with_optional(self): + parser = HfArgumentParser(OptionalExample) + + expected = argparse.ArgumentParser() + expected.add_argument("--foo", default=None, type=int) + expected.add_argument("--bar", default=None, type=float, help="help message") + expected.add_argument("--baz", default=None, type=str) + self.argparsersEqual(parser, expected) + + args = parser.parse_args([]) + self.assertEqual(args, Namespace(foo=None, bar=None, baz=None)) + + args = parser.parse_args("--foo 12 --bar 3.14 --baz 42".split()) + self.assertEqual(args, Namespace(foo=12, bar=3.14, baz="42")) + + def test_integration_training_args(self): + parser = HfArgumentParser(TrainingArguments) + self.assertIsNotNone(parser)