From 06a973fd2aa15f3818a57dc2bf678643904e494e Mon Sep 17 00:00:00 2001 From: Suraj Patil Date: Thu, 8 Oct 2020 22:36:35 +0530 Subject: [PATCH] [s2s] configure lr_scheduler from command line (#7641) --- examples/seq2seq/finetune_trainer.py | 5 +++- examples/seq2seq/seq2seq_trainer.py | 35 ++++++++++++++++++++++++++-- 2 files changed, 37 insertions(+), 3 deletions(-) diff --git a/examples/seq2seq/finetune_trainer.py b/examples/seq2seq/finetune_trainer.py index 5660f38360..39f5b7f55b 100644 --- a/examples/seq2seq/finetune_trainer.py +++ b/examples/seq2seq/finetune_trainer.py @@ -4,7 +4,7 @@ import sys from dataclasses import dataclass, field from typing import Optional -from seq2seq_trainer import Seq2SeqTrainer +from seq2seq_trainer import Seq2SeqTrainer, arg_to_scheduler_choices from transformers import ( AutoConfig, AutoModelForSeq2SeqLM, @@ -63,6 +63,9 @@ class Seq2SeqTrainingArguments(TrainingArguments): attention_dropout: Optional[float] = field( default=None, metadata={"help": "Attention dropout probability. Goes into model.config."} ) + lr_scheduler: Optional[str] = field( + default="linear", metadata={"help": f"Which lr scheduler to use. Selected in {arg_to_scheduler_choices}"} + ) @dataclass diff --git a/examples/seq2seq/seq2seq_trainer.py b/examples/seq2seq/seq2seq_trainer.py index 293244df41..c484499d26 100644 --- a/examples/seq2seq/seq2seq_trainer.py +++ b/examples/seq2seq/seq2seq_trainer.py @@ -8,7 +8,16 @@ from torch.utils.data import DistributedSampler, RandomSampler from transformers import Trainer from transformers.configuration_fsmt import FSMTConfig from transformers.file_utils import is_torch_tpu_available -from transformers.optimization import Adafactor, AdamW, get_linear_schedule_with_warmup +from transformers.optimization import ( + Adafactor, + AdamW, + get_constant_schedule, + get_constant_schedule_with_warmup, + get_cosine_schedule_with_warmup, + get_cosine_with_hard_restarts_schedule_with_warmup, + get_linear_schedule_with_warmup, + get_polynomial_decay_schedule_with_warmup, +) from transformers.trainer_pt_utils import get_tpu_sampler @@ -20,6 +29,16 @@ except ImportError: logger = logging.getLogger(__name__) +arg_to_scheduler = { + "linear": get_linear_schedule_with_warmup, + "cosine": get_cosine_schedule_with_warmup, + "cosine_w_restarts": get_cosine_with_hard_restarts_schedule_with_warmup, + "polynomial": get_polynomial_decay_schedule_with_warmup, + "constant": get_constant_schedule, + "constant_w_warmup": get_constant_schedule_with_warmup, +} +arg_to_scheduler_choices = sorted(arg_to_scheduler.keys()) + class Seq2SeqTrainer(Trainer): def __init__(self, config, data_args, *args, **kwargs): @@ -62,9 +81,21 @@ class Seq2SeqTrainer(Trainer): ) if self.lr_scheduler is None: - self.lr_scheduler = get_linear_schedule_with_warmup( + self.lr_scheduler = self._get_lr_scheduler(num_training_steps) + else: # ignoring --lr_scheduler + logger.warn("scheduler is passed to `Seq2SeqTrainer`, `--lr_scheduler` arg is ignored.") + + def _get_lr_scheduler(self, num_training_steps): + schedule_func = arg_to_scheduler[self.args.lr_scheduler] + if self.args.lr_scheduler == "constant": + scheduler = schedule_func(self.optimizer) + elif self.args.lr_scheduler == "constant_w_warmup": + scheduler = schedule_func(self.optimizer, num_warmup_steps=self.args.warmup_steps) + else: + scheduler = schedule_func( self.optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=num_training_steps ) + return scheduler def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]: if isinstance(self.train_dataset, torch.utils.data.IterableDataset):