From cafa6a9e29f3e99c67a1028f8ca779d439bc0689 Mon Sep 17 00:00:00 2001 From: Julien Chaumond Date: Thu, 7 May 2020 11:25:26 -0400 Subject: [PATCH] [Trainer] Ability to specify optimizer/scheduler at init cc @patrickvonplaten @thomwolf --- src/transformers/trainer.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index af54330776..3ee9bdd843 100644 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -113,6 +113,7 @@ class Trainer: compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None prediction_loss_only: bool tb_writer: Optional["SummaryWriter"] = None + optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = None def __init__( self, @@ -124,6 +125,7 @@ class Trainer: compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None, prediction_loss_only=False, tb_writer: Optional["SummaryWriter"] = None, + optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = None, ): """ Trainer is a simple but feature-complete training and eval loop for PyTorch, @@ -143,6 +145,7 @@ class Trainer: self.eval_dataset = eval_dataset self.compute_metrics = compute_metrics self.prediction_loss_only = prediction_loss_only + self.optimizers = optimizers if tb_writer is not None: self.tb_writer = tb_writer elif is_tensorboard_available() and self.args.local_rank in [-1, 0]: @@ -227,6 +230,15 @@ class Trainer: def get_optimizers( self, num_training_steps: int ) -> Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]: + """ + Set up the optimizer and the learning rate scheduler. + + We provide a reasonable default that works well. + If you want to use something else, you can pass a tuple in the Trainer's init, + or override this method in a subclass. 
+ """ + if self.optimizers is not None: + return self.optimizers # Prepare optimizer and schedule (linear warmup and decay) no_decay = ["bias", "LayerNorm.weight"] optimizer_grouped_parameters = [