BertAdam schedule objects
This commit is contained in:
@@ -23,29 +23,99 @@ import logging
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
def warmup_cosine(x, warmup=0.002):
|
|
||||||
if x < warmup:
|
|
||||||
return x/warmup
|
|
||||||
return 0.5 * (1.0 + torch.cos(math.pi * x))
|
|
||||||
|
|
||||||
def warmup_constant(x, warmup=0.002):
|
class LRSchedule(object):
|
||||||
""" Linearly increases learning rate over `warmup`*`t_total` (as provided to BertAdam) training steps.
|
warn_t_total = False
|
||||||
Learning rate is 1. afterwards. """
|
def __init__(self, warmup=0.002, t_total=-1, **kw):
|
||||||
if x < warmup:
|
super(LRSchedule, self).__init__(**kw)
|
||||||
return x/warmup
|
self.warmup, self.t_total = warmup, t_total
|
||||||
return 1.0
|
if t_total <= 0:
|
||||||
|
logger.warning("t_total value of {} results in schedule not being applied".format(t_total))
|
||||||
|
if not 0.0 <= warmup < 1.0 and not warmup == -1:
|
||||||
|
raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup))
|
||||||
|
self.warned_for_t_total_at_progress = -1
|
||||||
|
|
||||||
def warmup_linear(x, warmup=0.002):
|
def get_lr(self, step, nowarn=False):
|
||||||
""" Specifies a triangular learning rate schedule where peak is reached at `warmup`*`t_total`-th (as provided to BertAdam) training step.
|
progress = step / self.t_total
|
||||||
After `t_total`-th training step, learning rate is zero. """
|
ret = self.get_lr_(progress)
|
||||||
if x < warmup:
|
# warning for exceeding t_total (only active with warmup_linear
|
||||||
return x/warmup
|
if not nowarn and self.warn_t_total and progress > 1. and progress > self.warned_for_t_total_at_progress:
|
||||||
return max((x-1.)/(warmup-1.), 0)
|
logger.warning(
|
||||||
|
"Training beyond specified 't_total'. Learning rate multiplier set to {}. Please set 't_total' of {} correctly."
|
||||||
|
.format(ret, self.__class__.__name__))
|
||||||
|
self.warned_for_t_total_at_progress = progress
|
||||||
|
# end warning
|
||||||
|
return ret
|
||||||
|
|
||||||
|
def get_lr_(self, step):
|
||||||
|
return 1.
|
||||||
|
# raise NotImplemented("use subclass")
|
||||||
|
|
||||||
|
|
||||||
|
class WarmupCosineSchedule(LRSchedule):
|
||||||
|
warn_t_total = True
|
||||||
|
def __init__(self, warmup=0.002, t_total=-1, cycles=.5, **kw):
|
||||||
|
super(WarmupCosineSchedule, self).__init__(warmup=warmup, t_total=t_total, **kw)
|
||||||
|
self.cycles = cycles
|
||||||
|
|
||||||
|
def get_lr_(self, progress):
|
||||||
|
""" get learning rate multiplier """
|
||||||
|
if self.t_total <= 0:
|
||||||
|
return 1.
|
||||||
|
if progress < self.warmup:
|
||||||
|
return progress / self.warmup
|
||||||
|
else:
|
||||||
|
progress = (progress - self.warmup) / (1 - self.warmup) # progress after warmup
|
||||||
|
return 0.5 * (1. + torch.cos(math.pi * self.cycles * 2 * progress))
|
||||||
|
|
||||||
|
|
||||||
|
class WarmupConstantSchedule(LRSchedule):
|
||||||
|
warn_t_total = False
|
||||||
|
def get_lr_(self, progress):
|
||||||
|
if progress < self.warmup:
|
||||||
|
return progress / self.warmup
|
||||||
|
return 1.
|
||||||
|
|
||||||
|
|
||||||
|
class WarmupLinearSchedule(LRSchedule):
|
||||||
|
warn_t_total = True
|
||||||
|
def get_lr_(self, progress):
|
||||||
|
if progress < self.warmup:
|
||||||
|
return progress / self.warmup
|
||||||
|
return max((progress - 1.) / (self.warmup - 1.), 0)
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# def warmup_cosine(x, warmup=0.002):
|
||||||
|
# if x < warmup:
|
||||||
|
# return x/warmup
|
||||||
|
# return 0.5 * (1.0 + torch.cos(math.pi * x))
|
||||||
|
#
|
||||||
|
# def warmup_constant(x, warmup=0.002):
|
||||||
|
# """ Linearly increases learning rate over `warmup`*`t_total` (as provided to BertAdam) training steps.
|
||||||
|
# Learning rate is 1. afterwards. """
|
||||||
|
# if x < warmup:
|
||||||
|
# return x/warmup
|
||||||
|
# return 1.0
|
||||||
|
#
|
||||||
|
# def warmup_linear(x, warmup=0.002):
|
||||||
|
# """ Specifies a triangular learning rate schedule where peak is reached at `warmup`*`t_total`-th (as provided to BertAdam) training step.
|
||||||
|
# After `t_total`-th training step, learning rate is zero. """
|
||||||
|
# if x < warmup:
|
||||||
|
# return x/warmup
|
||||||
|
# return max((x-1.)/(warmup-1.), 0)
|
||||||
|
#
|
||||||
|
# SCHEDULES = {
|
||||||
|
# 'warmup_cosine': warmup_cosine,
|
||||||
|
# 'warmup_constant': warmup_constant,
|
||||||
|
# 'warmup_linear': warmup_linear,
|
||||||
|
# }
|
||||||
|
|
||||||
SCHEDULES = {
|
SCHEDULES = {
|
||||||
'warmup_cosine': warmup_cosine,
|
None: LRSchedule,
|
||||||
'warmup_constant': warmup_constant,
|
"none": LRSchedule,
|
||||||
'warmup_linear': warmup_linear,
|
"warmup_cosine": WarmupCosineSchedule,
|
||||||
|
"warmup_constant": WarmupConstantSchedule,
|
||||||
|
"warmup_linear": WarmupLinearSchedule
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -70,15 +140,16 @@ class BertAdam(Optimizer):
|
|||||||
raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
|
raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
|
||||||
if schedule not in SCHEDULES:
|
if schedule not in SCHEDULES:
|
||||||
raise ValueError("Invalid schedule parameter: {}".format(schedule))
|
raise ValueError("Invalid schedule parameter: {}".format(schedule))
|
||||||
if not 0.0 <= warmup < 1.0 and not warmup == -1:
|
|
||||||
raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup))
|
|
||||||
if not 0.0 <= b1 < 1.0:
|
if not 0.0 <= b1 < 1.0:
|
||||||
raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1))
|
raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1))
|
||||||
if not 0.0 <= b2 < 1.0:
|
if not 0.0 <= b2 < 1.0:
|
||||||
raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2))
|
raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2))
|
||||||
if not e >= 0.0:
|
if not e >= 0.0:
|
||||||
raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e))
|
raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e))
|
||||||
defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total,
|
# initialize schedule object
|
||||||
|
schedule_type = SCHEDULES[schedule]
|
||||||
|
sched = schedule_type(warmup=warmup, t_total=t_total)
|
||||||
|
defaults = dict(lr=lr, schedule=sched,
|
||||||
b1=b1, b2=b2, e=e, weight_decay=weight_decay,
|
b1=b1, b2=b2, e=e, weight_decay=weight_decay,
|
||||||
max_grad_norm=max_grad_norm)
|
max_grad_norm=max_grad_norm)
|
||||||
super(BertAdam, self).__init__(params, defaults)
|
super(BertAdam, self).__init__(params, defaults)
|
||||||
@@ -90,11 +161,10 @@ class BertAdam(Optimizer):
|
|||||||
state = self.state[p]
|
state = self.state[p]
|
||||||
if len(state) == 0:
|
if len(state) == 0:
|
||||||
return [0]
|
return [0]
|
||||||
if group['t_total'] != -1:
|
|
||||||
schedule_fct = SCHEDULES[group['schedule']]
|
|
||||||
lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup'])
|
|
||||||
else:
|
|
||||||
lr_scheduled = group['lr']
|
lr_scheduled = group['lr']
|
||||||
|
lr_scheduled *= group['schedule'](state['step'])
|
||||||
|
|
||||||
lr.append(lr_scheduled)
|
lr.append(lr_scheduled)
|
||||||
return lr
|
return lr
|
||||||
|
|
||||||
@@ -109,8 +179,6 @@ class BertAdam(Optimizer):
|
|||||||
if closure is not None:
|
if closure is not None:
|
||||||
loss = closure()
|
loss = closure()
|
||||||
|
|
||||||
warned_for_t_total = False
|
|
||||||
|
|
||||||
for group in self.param_groups:
|
for group in self.param_groups:
|
||||||
for p in group['params']:
|
for p in group['params']:
|
||||||
if p.grad is None:
|
if p.grad is None:
|
||||||
@@ -152,19 +220,8 @@ class BertAdam(Optimizer):
|
|||||||
if group['weight_decay'] > 0.0:
|
if group['weight_decay'] > 0.0:
|
||||||
update += group['weight_decay'] * p.data
|
update += group['weight_decay'] * p.data
|
||||||
|
|
||||||
if group['t_total'] != -1:
|
|
||||||
schedule_fct = SCHEDULES[group['schedule']]
|
|
||||||
progress = state['step']/group['t_total']
|
|
||||||
lr_scheduled = group['lr'] * schedule_fct(progress, group['warmup'])
|
|
||||||
# warning for exceeding t_total (only active with warmup_linear
|
|
||||||
if group['schedule'] == "warmup_linear" and progress > 1. and not warned_for_t_total:
|
|
||||||
logger.warning(
|
|
||||||
"Training beyond specified 't_total' steps with schedule '{}'. Learning rate set to {}. "
|
|
||||||
"Please set 't_total' of {} correctly.".format(group['schedule'], lr_scheduled, self.__class__.__name__))
|
|
||||||
warned_for_t_total = True
|
|
||||||
# end warning
|
|
||||||
else:
|
|
||||||
lr_scheduled = group['lr']
|
lr_scheduled = group['lr']
|
||||||
|
lr_scheduled *= group['schedule'](state['step'])
|
||||||
|
|
||||||
update_with_lr = lr_scheduled * update
|
update_with_lr = lr_scheduled * update
|
||||||
p.data.add_(-update_with_lr)
|
p.data.add_(-update_with_lr)
|
||||||
|
|||||||
Reference in New Issue
Block a user