From 88874f6cf09e14fc482abc186adebb2767dca258 Mon Sep 17 00:00:00 2001 From: lukovnikov Date: Fri, 8 Mar 2019 19:08:30 +0100 Subject: [PATCH] BertAdam schedule objects --- pytorch_pretrained_bert/optimization.py | 141 +++++++++++++++++------- 1 file changed, 99 insertions(+), 42 deletions(-) diff --git a/pytorch_pretrained_bert/optimization.py b/pytorch_pretrained_bert/optimization.py index fa911e5c04..73afc71058 100644 --- a/pytorch_pretrained_bert/optimization.py +++ b/pytorch_pretrained_bert/optimization.py @@ -23,29 +23,99 @@ import logging logger = logging.getLogger(__name__) -def warmup_cosine(x, warmup=0.002): - if x < warmup: - return x/warmup - return 0.5 * (1.0 + torch.cos(math.pi * x)) -def warmup_constant(x, warmup=0.002): - """ Linearly increases learning rate over `warmup`*`t_total` (as provided to BertAdam) training steps. - Learning rate is 1. afterwards. """ - if x < warmup: - return x/warmup - return 1.0 +class LRSchedule(object): + warn_t_total = False + def __init__(self, warmup=0.002, t_total=-1, **kw): + super(LRSchedule, self).__init__(**kw) + self.warmup, self.t_total = warmup, t_total + if t_total <= 0: + logger.warning("t_total value of {} results in schedule not being applied".format(t_total)) + if not 0.0 <= warmup < 1.0 and not warmup == -1: + raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup)) + self.warned_for_t_total_at_progress = -1 -def warmup_linear(x, warmup=0.002): - """ Specifies a triangular learning rate schedule where peak is reached at `warmup`*`t_total`-th (as provided to BertAdam) training step. - After `t_total`-th training step, learning rate is zero. """ - if x < warmup: - return x/warmup - return max((x-1.)/(warmup-1.), 0) + def get_lr(self, step, nowarn=False): + progress = step / self.t_total + ret = self.get_lr_(progress) + # warning for exceeding t_total (only active with warmup_linear + if not nowarn and self.warn_t_total and progress > 1. and progress > self.warned_for_t_total_at_progress: + logger.warning( + "Training beyond specified 't_total'. Learning rate multiplier set to {}. Please set 't_total' of {} correctly." + .format(ret, self.__class__.__name__)) + self.warned_for_t_total_at_progress = progress + # end warning + return ret + + def get_lr_(self, step): + return 1. + # raise NotImplemented("use subclass") + + +class WarmupCosineSchedule(LRSchedule): + warn_t_total = True + def __init__(self, warmup=0.002, t_total=-1, cycles=.5, **kw): + super(WarmupCosineSchedule, self).__init__(warmup=warmup, t_total=t_total, **kw) + self.cycles = cycles + + def get_lr_(self, progress): + """ get learning rate multiplier """ + if self.t_total <= 0: + return 1. + if progress < self.warmup: + return progress / self.warmup + else: + progress = (progress - self.warmup) / (1 - self.warmup) # progress after warmup + return 0.5 * (1. + torch.cos(math.pi * self.cycles * 2 * progress)) + + +class WarmupConstantSchedule(LRSchedule): + warn_t_total = False + def get_lr_(self, progress): + if progress < self.warmup: + return progress / self.warmup + return 1. + + +class WarmupLinearSchedule(LRSchedule): + warn_t_total = True + def get_lr_(self, progress): + if progress < self.warmup: + return progress / self.warmup + return max((progress - 1.) / (self.warmup - 1.), 0) +# +# +# def warmup_cosine(x, warmup=0.002): +# if x < warmup: +# return x/warmup +# return 0.5 * (1.0 + torch.cos(math.pi * x)) +# +# def warmup_constant(x, warmup=0.002): +# """ Linearly increases learning rate over `warmup`*`t_total` (as provided to BertAdam) training steps. +# Learning rate is 1. afterwards. """ +# if x < warmup: +# return x/warmup +# return 1.0 +# +# def warmup_linear(x, warmup=0.002): +# """ Specifies a triangular learning rate schedule where peak is reached at `warmup`*`t_total`-th (as provided to BertAdam) training step. +# After `t_total`-th training step, learning rate is zero. """ +# if x < warmup: +# return x/warmup +# return max((x-1.)/(warmup-1.), 0) +# +# SCHEDULES = { +# 'warmup_cosine': warmup_cosine, +# 'warmup_constant': warmup_constant, +# 'warmup_linear': warmup_linear, +# } SCHEDULES = { - 'warmup_cosine': warmup_cosine, - 'warmup_constant': warmup_constant, - 'warmup_linear': warmup_linear, + None: LRSchedule, + "none": LRSchedule, + "warmup_cosine": WarmupCosineSchedule, + "warmup_constant": WarmupConstantSchedule, + "warmup_linear": WarmupLinearSchedule } @@ -70,15 +140,16 @@ class BertAdam(Optimizer): raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr)) if schedule not in SCHEDULES: raise ValueError("Invalid schedule parameter: {}".format(schedule)) - if not 0.0 <= warmup < 1.0 and not warmup == -1: - raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup)) if not 0.0 <= b1 < 1.0: raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1)) if not 0.0 <= b2 < 1.0: raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2)) if not e >= 0.0: raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e)) - defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total, + # initialize schedule object + schedule_type = SCHEDULES[schedule] + sched = schedule_type(warmup=warmup, t_total=t_total) + defaults = dict(lr=lr, schedule=sched, b1=b1, b2=b2, e=e, weight_decay=weight_decay, max_grad_norm=max_grad_norm) super(BertAdam, self).__init__(params, defaults) @@ -90,11 +161,10 @@ class BertAdam(Optimizer): state = self.state[p] if len(state) == 0: return [0] - if group['t_total'] != -1: - schedule_fct = SCHEDULES[group['schedule']] - lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup']) - else: - lr_scheduled = group['lr'] + + lr_scheduled = group['lr'] + lr_scheduled *= group['schedule'](state['step']) + lr.append(lr_scheduled) return lr @@ -109,8 +179,6 @@ class BertAdam(Optimizer): if closure is not None: loss = closure() - warned_for_t_total = False - for group in self.param_groups: for p in group['params']: if p.grad is None: @@ -152,19 +220,8 @@ class BertAdam(Optimizer): if group['weight_decay'] > 0.0: update += group['weight_decay'] * p.data - if group['t_total'] != -1: - schedule_fct = SCHEDULES[group['schedule']] - progress = state['step']/group['t_total'] - lr_scheduled = group['lr'] * schedule_fct(progress, group['warmup']) - # warning for exceeding t_total (only active with warmup_linear - if group['schedule'] == "warmup_linear" and progress > 1. and not warned_for_t_total: - logger.warning( - "Training beyond specified 't_total' steps with schedule '{}'. Learning rate set to {}. " - "Please set 't_total' of {} correctly.".format(group['schedule'], lr_scheduled, self.__class__.__name__)) - warned_for_t_total = True - # end warning - else: - lr_scheduled = group['lr'] + lr_scheduled = group['lr'] + lr_scheduled *= group['schedule'](state['step']) update_with_lr = lr_scheduled * update p.data.add_(-update_with_lr)