[Trainer] Ability to specify optimizer/scheduler at init
cc @patrickvonplaten @thomwolf
This commit is contained in:
@@ -113,6 +113,7 @@ class Trainer:
|
|||||||
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None
|
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None
|
||||||
prediction_loss_only: bool
|
prediction_loss_only: bool
|
||||||
tb_writer: Optional["SummaryWriter"] = None
|
tb_writer: Optional["SummaryWriter"] = None
|
||||||
|
optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = None
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -124,6 +125,7 @@ class Trainer:
|
|||||||
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
|
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
|
||||||
prediction_loss_only=False,
|
prediction_loss_only=False,
|
||||||
tb_writer: Optional["SummaryWriter"] = None,
|
tb_writer: Optional["SummaryWriter"] = None,
|
||||||
|
optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Trainer is a simple but feature-complete training and eval loop for PyTorch,
|
Trainer is a simple but feature-complete training and eval loop for PyTorch,
|
||||||
@@ -143,6 +145,7 @@ class Trainer:
|
|||||||
self.eval_dataset = eval_dataset
|
self.eval_dataset = eval_dataset
|
||||||
self.compute_metrics = compute_metrics
|
self.compute_metrics = compute_metrics
|
||||||
self.prediction_loss_only = prediction_loss_only
|
self.prediction_loss_only = prediction_loss_only
|
||||||
|
self.optimizers = optimizers
|
||||||
if tb_writer is not None:
|
if tb_writer is not None:
|
||||||
self.tb_writer = tb_writer
|
self.tb_writer = tb_writer
|
||||||
elif is_tensorboard_available() and self.args.local_rank in [-1, 0]:
|
elif is_tensorboard_available() and self.args.local_rank in [-1, 0]:
|
||||||
@@ -227,6 +230,15 @@ class Trainer:
|
|||||||
def get_optimizers(
|
def get_optimizers(
|
||||||
self, num_training_steps: int
|
self, num_training_steps: int
|
||||||
) -> Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]:
|
) -> Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]:
|
||||||
|
"""
|
||||||
|
Setup the optimizer and the learning rate scheduler.
|
||||||
|
|
||||||
|
We provide a reasonable default that works well.
|
||||||
|
If you want to use something else, you can pass a tuple in the Trainer's init,
|
||||||
|
or override this method in a subclass.
|
||||||
|
"""
|
||||||
|
if self.optimizers is not None:
|
||||||
|
return self.optimizers
|
||||||
# Prepare optimizer and schedule (linear warmup and decay)
|
# Prepare optimizer and schedule (linear warmup and decay)
|
||||||
no_decay = ["bias", "LayerNorm.weight"]
|
no_decay = ["bias", "LayerNorm.weight"]
|
||||||
optimizer_grouped_parameters = [
|
optimizer_grouped_parameters = [
|
||||||
|
|||||||
Reference in New Issue
Block a user