From 0d07a23c04c9837234cda402b32246d7581e5bc4 Mon Sep 17 00:00:00 2001 From: Lysandre Date: Fri, 1 Nov 2019 15:07:01 +0000 Subject: [PATCH] LAMB implementation --- transformers/optimization.py | 93 ++++++++++++++++++++++++++++++++++++ 1 file changed, 93 insertions(+) diff --git a/transformers/optimization.py b/transformers/optimization.py index 99e6cc75e4..90f3dbca9b 100644 --- a/transformers/optimization.py +++ b/transformers/optimization.py @@ -167,3 +167,96 @@ class AdamW(Optimizer): p.data.add_(-group['lr'] * group['weight_decay'], p.data) return loss + + + +class Lamb(Optimizer): + """ Implements the LAMB algorithm (Layer-wise Adaptive Moments optimizer for Batch training). + + Adapted from the huggingface/transformers ADAM optimizer + Inspired from the Google Research implementation available in ALBERT: https://github.com/google-research/google-research/blob/master/albert/lamb_optimizer.py + Inspired from cybertronai's PyTorch LAMB implementation: https://github.com/cybertronai/pytorch-lamb/blob/master/pytorch_lamb/lamb.py + + + Parameters: + lr (float): learning rate. Default 1e-3. + betas (tuple of 2 floats): Adams beta parameters (b1, b2). Default: (0.9, 0.999) + eps (float): Adams epsilon. Default: 1e-6 + weight_decay (float): Weight decay. Default: 0.0 + """ + + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, weight_decay=0.0, correct_bias=True): + if lr < 0.0: + raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[1])) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(eps)) + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, + correct_bias=correct_bias) + super(Lamb, self).__init__(params, defaults) + + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + grad = p.grad.data + if grad.is_sparse: + raise RuntimeError('LAMB does not support sparse gradients.') + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = 0 + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p.data) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like(p.data) + + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + beta1, beta2 = group['betas'] + + state['step'] += 1 + + # Decay the first and second moment running average coefficient + # In-place operations to update the averages at the same time + exp_avg.mul_(beta1).add_(1.0 - beta1, grad) + exp_avg_sq.mul_(beta2).addcmul_(1.0 - beta2, grad, grad) + denom = exp_avg_sq.sqrt().add_(group['eps']) + + + # Inspired from cybertronai's PyTorch LAMB implementation: https://github.com/cybertronai/pytorch-lamb/blob/master/pytorch_lamb/lamb.py + step_size = group['lr'] + weight_norm = p.data.pow(2).sum().sqrt().clamp(0, 10) + + adam_step = exp_avg / exp_avg_sq.sqrt().add(group['eps']) + if group['weight_decay'] != 0: + adam_step.add_(group['weight_decay'], p.data) + + adam_norm = adam_step.pow(2).sum().sqrt() + if weight_norm == 0 or adam_norm == 0: + trust_ratio = 1 + else: + trust_ratio = weight_norm / adam_norm + + + state['weight_norm'] = weight_norm + state['adam_norm'] = adam_norm + state['trust_ratio'] = trust_ratio + + p.data.add_(-step_size * trust_ratio, adam_step) + return loss