[s2s] configure lr_scheduler from command line (#7641)

This commit is contained in:
Suraj Patil
2020-10-08 22:36:35 +05:30
committed by GitHub
parent 4a00613c24
commit 06a973fd2a
2 changed files with 37 additions and 3 deletions

View File

@@ -8,7 +8,16 @@ from torch.utils.data import DistributedSampler, RandomSampler
from transformers import Trainer
from transformers.configuration_fsmt import FSMTConfig
from transformers.file_utils import is_torch_tpu_available
from transformers.optimization import Adafactor, AdamW, get_linear_schedule_with_warmup
from transformers.optimization import (
Adafactor,
AdamW,
get_constant_schedule,
get_constant_schedule_with_warmup,
get_cosine_schedule_with_warmup,
get_cosine_with_hard_restarts_schedule_with_warmup,
get_linear_schedule_with_warmup,
get_polynomial_decay_schedule_with_warmup,
)
from transformers.trainer_pt_utils import get_tpu_sampler
@@ -20,6 +29,16 @@ except ImportError:
logger = logging.getLogger(__name__)
arg_to_scheduler = {
"linear": get_linear_schedule_with_warmup,
"cosine": get_cosine_schedule_with_warmup,
"cosine_w_restarts": get_cosine_with_hard_restarts_schedule_with_warmup,
"polynomial": get_polynomial_decay_schedule_with_warmup,
"constant": get_constant_schedule,
"constant_w_warmup": get_constant_schedule_with_warmup,
}
arg_to_scheduler_choices = sorted(arg_to_scheduler.keys())
class Seq2SeqTrainer(Trainer):
def __init__(self, config, data_args, *args, **kwargs):
@@ -62,9 +81,21 @@ class Seq2SeqTrainer(Trainer):
)
if self.lr_scheduler is None:
self.lr_scheduler = get_linear_schedule_with_warmup(
self.lr_scheduler = self._get_lr_scheduler(num_training_steps)
else: # ignoring --lr_scheduler
logger.warn("scheduler is passed to `Seq2SeqTrainer`, `--lr_scheduler` arg is ignored.")
def _get_lr_scheduler(self, num_training_steps):
schedule_func = arg_to_scheduler[self.args.lr_scheduler]
if self.args.lr_scheduler == "constant":
scheduler = schedule_func(self.optimizer)
elif self.args.lr_scheduler == "constant_w_warmup":
scheduler = schedule_func(self.optimizer, num_warmup_steps=self.args.warmup_steps)
else:
scheduler = schedule_func(
self.optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=num_training_steps
)
return scheduler
def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
if isinstance(self.train_dataset, torch.utils.data.IterableDataset):