From 068e6b5eddb5dba296f0c42f36fffb9368c9fe24 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 3 Nov 2020 21:13:33 +0100 Subject: [PATCH] make files independent (#8267) --- examples/seq2seq/finetune_trainer.py | 48 ++--------------------- examples/seq2seq/seq2seq_trainer.py | 1 - examples/seq2seq/seq2seq_training_args.py | 45 +++++++++++++++++++++ 3 files changed, 48 insertions(+), 46 deletions(-) create mode 100644 examples/seq2seq/seq2seq_training_args.py diff --git a/examples/seq2seq/finetune_trainer.py b/examples/seq2seq/finetune_trainer.py index 29632ea91f..dd394365e9 100644 --- a/examples/seq2seq/finetune_trainer.py +++ b/examples/seq2seq/finetune_trainer.py @@ -4,16 +4,9 @@ import sys from dataclasses import dataclass, field from typing import Optional -from seq2seq_trainer import Seq2SeqTrainer, arg_to_scheduler_choices -from transformers import ( - AutoConfig, - AutoModelForSeq2SeqLM, - AutoTokenizer, - HfArgumentParser, - MBartTokenizer, - TrainingArguments, - set_seed, -) +from seq2seq_trainer import Seq2SeqTrainer +from seq2seq_training_args import Seq2SeqTrainingArguments +from transformers import AutoConfig, AutoModelForSeq2SeqLM, AutoTokenizer, HfArgumentParser, MBartTokenizer, set_seed from transformers.trainer_utils import EvaluationStrategy from utils import ( Seq2SeqDataCollator, @@ -33,41 +26,6 @@ from utils import ( logger = logging.getLogger(__name__) -@dataclass -class Seq2SeqTrainingArguments(TrainingArguments): - """ - Parameters: - label_smoothing (:obj:`float`, `optional`, defaults to 0): - The label smoothing epsilon to apply (if not zero). - sortish_sampler (:obj:`bool`, `optional`, defaults to :obj:`False`): - Whether to SortishSamler or not. It sorts the inputs according to lenghts in-order to minimizing the padding size. - predict_with_generate (:obj:`bool`, `optional`, defaults to :obj:`False`): - Whether to use generate to calculate generative metrics (ROUGE, BLEU). - """ - - label_smoothing: Optional[float] = field( - default=0.0, metadata={"help": "The label smoothing epsilon to apply (if not zero)."} - ) - sortish_sampler: bool = field(default=False, metadata={"help": "Whether to SortishSamler or not."}) - predict_with_generate: bool = field( - default=False, metadata={"help": "Whether to use generate to calculate generative metrics (ROUGE, BLEU)."} - ) - adafactor: bool = field(default=False, metadata={"help": "whether to use adafactor"}) - encoder_layerdrop: Optional[float] = field( - default=None, metadata={"help": "Encoder layer dropout probability. Goes into model.config."} - ) - decoder_layerdrop: Optional[float] = field( - default=None, metadata={"help": "Decoder layer dropout probability. Goes into model.config."} - ) - dropout: Optional[float] = field(default=None, metadata={"help": "Dropout probability. Goes into model.config."}) - attention_dropout: Optional[float] = field( - default=None, metadata={"help": "Attention dropout probability. Goes into model.config."} - ) - lr_scheduler: Optional[str] = field( - default="linear", metadata={"help": f"Which lr scheduler to use. Selected in {arg_to_scheduler_choices}"} - ) - - @dataclass class ModelArguments: """ diff --git a/examples/seq2seq/seq2seq_trainer.py b/examples/seq2seq/seq2seq_trainer.py index dc4aeca593..805a73871f 100644 --- a/examples/seq2seq/seq2seq_trainer.py +++ b/examples/seq2seq/seq2seq_trainer.py @@ -30,7 +30,6 @@ arg_to_scheduler = { "constant": get_constant_schedule, "constant_w_warmup": get_constant_schedule_with_warmup, } -arg_to_scheduler_choices = sorted(arg_to_scheduler.keys()) class Seq2SeqTrainer(Trainer): diff --git a/examples/seq2seq/seq2seq_training_args.py b/examples/seq2seq/seq2seq_training_args.py new file mode 100644 index 0000000000..0bd486026a --- /dev/null +++ b/examples/seq2seq/seq2seq_training_args.py @@ -0,0 +1,45 @@ +import logging +from dataclasses import dataclass, field +from typing import Optional + +from seq2seq_trainer import arg_to_scheduler +from transformers import TrainingArguments + + +logger = logging.getLogger(__name__) + + +@dataclass +class Seq2SeqTrainingArguments(TrainingArguments): + """ + Parameters: + label_smoothing (:obj:`float`, `optional`, defaults to 0): + The label smoothing epsilon to apply (if not zero). + sortish_sampler (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether to SortishSamler or not. It sorts the inputs according to lenghts in-order to minimizing the padding size. + predict_with_generate (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether to use generate to calculate generative metrics (ROUGE, BLEU). + """ + + label_smoothing: Optional[float] = field( + default=0.0, metadata={"help": "The label smoothing epsilon to apply (if not zero)."} + ) + sortish_sampler: bool = field(default=False, metadata={"help": "Whether to SortishSamler or not."}) + predict_with_generate: bool = field( + default=False, metadata={"help": "Whether to use generate to calculate generative metrics (ROUGE, BLEU)."} + ) + adafactor: bool = field(default=False, metadata={"help": "whether to use adafactor"}) + encoder_layerdrop: Optional[float] = field( + default=None, metadata={"help": "Encoder layer dropout probability. Goes into model.config."} + ) + decoder_layerdrop: Optional[float] = field( + default=None, metadata={"help": "Decoder layer dropout probability. Goes into model.config."} + ) + dropout: Optional[float] = field(default=None, metadata={"help": "Dropout probability. Goes into model.config."}) + attention_dropout: Optional[float] = field( + default=None, metadata={"help": "Attention dropout probability. Goes into model.config."} + ) + lr_scheduler: Optional[str] = field( + default="linear", + metadata={"help": f"Which lr scheduler to use. Selected in {sorted(arg_to_scheduler.keys())}"}, + )