[s2s] configure lr_scheduler from command line (#7641)
This commit is contained in:
@@ -4,7 +4,7 @@ import sys
|
|||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from seq2seq_trainer import Seq2SeqTrainer
|
from seq2seq_trainer import Seq2SeqTrainer, arg_to_scheduler_choices
|
||||||
from transformers import (
|
from transformers import (
|
||||||
AutoConfig,
|
AutoConfig,
|
||||||
AutoModelForSeq2SeqLM,
|
AutoModelForSeq2SeqLM,
|
||||||
@@ -63,6 +63,9 @@ class Seq2SeqTrainingArguments(TrainingArguments):
|
|||||||
attention_dropout: Optional[float] = field(
|
attention_dropout: Optional[float] = field(
|
||||||
default=None, metadata={"help": "Attention dropout probability. Goes into model.config."}
|
default=None, metadata={"help": "Attention dropout probability. Goes into model.config."}
|
||||||
)
|
)
|
||||||
|
lr_scheduler: Optional[str] = field(
|
||||||
|
default="linear", metadata={"help": f"Which lr scheduler to use. Selected in {arg_to_scheduler_choices}"}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
@@ -8,7 +8,16 @@ from torch.utils.data import DistributedSampler, RandomSampler
|
|||||||
from transformers import Trainer
|
from transformers import Trainer
|
||||||
from transformers.configuration_fsmt import FSMTConfig
|
from transformers.configuration_fsmt import FSMTConfig
|
||||||
from transformers.file_utils import is_torch_tpu_available
|
from transformers.file_utils import is_torch_tpu_available
|
||||||
from transformers.optimization import Adafactor, AdamW, get_linear_schedule_with_warmup
|
from transformers.optimization import (
|
||||||
|
Adafactor,
|
||||||
|
AdamW,
|
||||||
|
get_constant_schedule,
|
||||||
|
get_constant_schedule_with_warmup,
|
||||||
|
get_cosine_schedule_with_warmup,
|
||||||
|
get_cosine_with_hard_restarts_schedule_with_warmup,
|
||||||
|
get_linear_schedule_with_warmup,
|
||||||
|
get_polynomial_decay_schedule_with_warmup,
|
||||||
|
)
|
||||||
from transformers.trainer_pt_utils import get_tpu_sampler
|
from transformers.trainer_pt_utils import get_tpu_sampler
|
||||||
|
|
||||||
|
|
||||||
@@ -20,6 +29,16 @@ except ImportError:
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Maps every accepted `lr_scheduler` option value to the schedule factory
# (imported from `transformers.optimization`) that builds it.
arg_to_scheduler = {
    "linear": get_linear_schedule_with_warmup,
    "cosine": get_cosine_schedule_with_warmup,
    "cosine_w_restarts": get_cosine_with_hard_restarts_schedule_with_warmup,
    "polynomial": get_polynomial_decay_schedule_with_warmup,
    "constant": get_constant_schedule,
    "constant_w_warmup": get_constant_schedule_with_warmup,
}
# Alphabetically sorted option names; exported so the argument definition can
# list them in its help text. Iterating a dict yields its keys, so `.keys()`
# is not needed.
arg_to_scheduler_choices = sorted(arg_to_scheduler)
|
||||||
|
|
||||||
|
|
||||||
class Seq2SeqTrainer(Trainer):
|
class Seq2SeqTrainer(Trainer):
|
||||||
def __init__(self, config, data_args, *args, **kwargs):
|
def __init__(self, config, data_args, *args, **kwargs):
|
||||||
@@ -62,9 +81,21 @@ class Seq2SeqTrainer(Trainer):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if self.lr_scheduler is None:
|
if self.lr_scheduler is None:
|
||||||
self.lr_scheduler = get_linear_schedule_with_warmup(
|
self.lr_scheduler = self._get_lr_scheduler(num_training_steps)
|
||||||
|
else: # ignoring --lr_scheduler
|
||||||
|
logger.warn("scheduler is passed to `Seq2SeqTrainer`, `--lr_scheduler` arg is ignored.")
|
||||||
|
|
||||||
|
def _get_lr_scheduler(self, num_training_steps):
    """Build the learning-rate scheduler named by ``self.args.lr_scheduler``.

    Args:
        num_training_steps: Total number of training steps; forwarded only to
            the schedules that accept it (everything except the two
            ``constant*`` variants).

    Returns:
        The scheduler object created by the selected factory for
        ``self.optimizer``.

    Raises:
        ValueError: If ``self.args.lr_scheduler`` is not a key of
            ``arg_to_scheduler``.
    """
    name = self.args.lr_scheduler
    try:
        schedule_func = arg_to_scheduler[name]
    except KeyError:
        # A bare ``KeyError: 'foo'`` gives no hint about what is accepted,
        # so report the valid option names instead.
        raise ValueError(
            f"Unknown lr_scheduler {name!r}; expected one of {sorted(arg_to_scheduler)}"
        ) from None

    # The factories do not share one signature: the plain constant schedule
    # takes only the optimizer, the warmup-only variant has no notion of a
    # total step count, and the rest need both warmup and total steps.
    if name == "constant":
        return schedule_func(self.optimizer)
    if name == "constant_w_warmup":
        return schedule_func(self.optimizer, num_warmup_steps=self.args.warmup_steps)
    return schedule_func(
        self.optimizer,
        num_warmup_steps=self.args.warmup_steps,
        num_training_steps=num_training_steps,
    )
|
||||||
|
|
||||||
def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
|
def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
|
||||||
if isinstance(self.train_dataset, torch.utils.data.IterableDataset):
|
if isinstance(self.train_dataset, torch.utils.data.IterableDataset):
|
||||||
|
|||||||
Reference in New Issue
Block a user