[Trainer] Ability to specify optimizer/scheduler at init
cc @patrickvonplaten @thomwolf
This commit is contained in:
@@ -113,6 +113,7 @@ class Trainer:
|
||||
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None
|
||||
prediction_loss_only: bool
|
||||
tb_writer: Optional["SummaryWriter"] = None
|
||||
optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -124,6 +125,7 @@ class Trainer:
|
||||
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
|
||||
prediction_loss_only=False,
|
||||
tb_writer: Optional["SummaryWriter"] = None,
|
||||
optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = None,
|
||||
):
|
||||
"""
|
||||
Trainer is a simple but feature-complete training and eval loop for PyTorch,
|
||||
@@ -143,6 +145,7 @@ class Trainer:
|
||||
self.eval_dataset = eval_dataset
|
||||
self.compute_metrics = compute_metrics
|
||||
self.prediction_loss_only = prediction_loss_only
|
||||
self.optimizers = optimizers
|
||||
if tb_writer is not None:
|
||||
self.tb_writer = tb_writer
|
||||
elif is_tensorboard_available() and self.args.local_rank in [-1, 0]:
|
||||
@@ -227,6 +230,15 @@ class Trainer:
|
||||
def get_optimizers(
|
||||
self, num_training_steps: int
|
||||
) -> Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]:
|
||||
"""
|
||||
Setup the optimizer and the learning rate scheduler.
|
||||
|
||||
We provide a reasonable default that works well.
|
||||
If you want to use something else, you can pass a tuple in the Trainer's init,
|
||||
or override this method in a subclass.
|
||||
"""
|
||||
if self.optimizers is not None:
|
||||
return self.optimizers
|
||||
# Prepare optimizer and schedule (linear warmup and decay)
|
||||
no_decay = ["bias", "LayerNorm.weight"]
|
||||
optimizer_grouped_parameters = [
|
||||
|
||||
Reference in New Issue
Block a user