[Trainer] Ability to specify optimizer/scheduler at init

cc @patrickvonplaten @thomwolf
This commit is contained in:
Julien Chaumond
2020-05-07 11:25:26 -04:00
parent e4fd5e3999
commit cafa6a9e29

View File

@@ -113,6 +113,7 @@ class Trainer:
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None
prediction_loss_only: bool
tb_writer: Optional["SummaryWriter"] = None
optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = None
def __init__(
self,
@@ -124,6 +125,7 @@ class Trainer:
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
prediction_loss_only=False,
tb_writer: Optional["SummaryWriter"] = None,
optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = None,
):
"""
Trainer is a simple but feature-complete training and eval loop for PyTorch,
@@ -143,6 +145,7 @@ class Trainer:
self.eval_dataset = eval_dataset
self.compute_metrics = compute_metrics
self.prediction_loss_only = prediction_loss_only
self.optimizers = optimizers
if tb_writer is not None:
self.tb_writer = tb_writer
elif is_tensorboard_available() and self.args.local_rank in [-1, 0]:
@@ -227,6 +230,15 @@ class Trainer:
def get_optimizers(
self, num_training_steps: int
) -> Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]:
"""
Setup the optimizer and the learning rate scheduler.
We provide a reasonable default that works well.
If you want to use something else, you can pass a tuple in the Trainer's init,
or override this method in a subclass.
"""
if self.optimizers is not None:
return self.optimizers
# Prepare optimizer and schedule (linear warmup and decay)
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [