diff --git a/optimization.py b/optimization.py index 91701e4255..e04f01b1d2 100644 --- a/optimization.py +++ b/optimization.py @@ -42,17 +42,18 @@ SCHEDULES = { class BERTAdam(Optimizer): - """Implements Open AI version of Adam algorithm with weight decay fix. + """Implements BERT version of Adam algorithm with weight decay fix. Params: - lr, - warmup=-1, - t_total=-1, - schedule='warmup_linear', - b1=0.9, - b2=0.999, - e=1e-6, - weight_decay_rate=0.01, - max_grad_norm=1.0 + lr: learning rate + warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1 + t_total: total number of training steps for the learning + rate schedule, -1 means constant learning rate. Default: -1 + schedule: schedule to use for the warmup (see above). Default: 'warmup_linear' + b1: Adam's b1. Default: 0.9 + b2: Adam's b2. Default: 0.999 + e: Adam's epsilon. Default: 1e-6 + weight_decay_rate: Weight decay. Default: 0.01 + max_grad_norm: Maximum norm for the gradients (-1 means no clipping). Default: 1.0 """ def __init__(self, params, lr, warmup=-1, t_total=-1, schedule='warmup_linear', b1=0.9, b2=0.999, e=1e-6, weight_decay_rate=0.01,