- updated docs for optimization

This commit is contained in:
lukovnikov
2019-04-03 16:08:34 +02:00
parent 725a56329d
commit 1758c8fc72
2 changed files with 70 additions and 69 deletions

View File

@@ -25,12 +25,18 @@ logger = logging.getLogger(__name__)
__all__ = ["LRSchedule", "WarmupLinearSchedule", "WarmupConstantSchedule", "WarmupCosineSchedule", "BertAdam",
"WarmupMultiCosineSchedule", "WarmupCosineWithRestartsSchedule"]
"WarmupCosineWithHardRestartsSchedule", "WarmupCosineWithWarmupRestartsSchedule", "SCHEDULES"]
class LRSchedule(object):
warn_t_total = False
""" Parent of all LRSchedules here. """
warn_t_total = False # is set to True for schedules where progressing beyond t_total steps doesn't make sense
def __init__(self, warmup=0.002, t_total=-1, **kw):
"""
:param warmup: what fraction of t_total steps will be used for linear warmup
:param t_total: how many training steps (updates) are planned
:param kw: additional keyword arguments, forwarded to the parent class constructor
"""
super(LRSchedule, self).__init__(**kw)
self.warmup, self.t_total = warmup, t_total
if t_total <= 0:
@@ -40,6 +46,11 @@ class LRSchedule(object):
self.warned_for_t_total_at_progress = -1
def get_lr(self, step, nowarn=False):
"""
:param step: which of t_total steps we're on
:param nowarn: set to True to suppress warning regarding training beyond specified 't_total' steps
:return: learning rate multiplier for current update
"""
progress = step / self.t_total
ret = self.get_lr_(progress)
# warning for exceeding t_total (only active with warmup_linear)
@@ -51,14 +62,27 @@ class LRSchedule(object):
# end warning
return ret
def get_lr_(self, step):
def get_lr_(self, progress):
"""
:param progress: value between 0 and 1 (unless going beyond t_total steps) specifying training progress
:return: learning rate multiplier for current update
"""
return 1.
# raise NotImplemented("use subclass") -
class WarmupCosineSchedule(LRSchedule):
"""
Cosine learning rate schedule with linear warmup. Cosine after warmup is without restarts.
"""
warn_t_total = True
def __init__(self, warmup=0.002, t_total=-1, cycles=.5, **kw):
"""
:param warmup: see LRSchedule
:param t_total: see LRSchedule
:param cycles: number of cycles. Default: 0.5, corresponding to cosine decay from 1. at progress==warmup to 0. at progress==1.
:param kw: additional keyword arguments, forwarded to the LRSchedule constructor
"""
super(WarmupCosineSchedule, self).__init__(warmup=warmup, t_total=t_total, **kw)
self.cycles = cycles
@@ -73,10 +97,12 @@ class WarmupCosineSchedule(LRSchedule):
return 0.5 * (1. + math.cos(math.pi * self.cycles * 2 * progress))
class WarmupMultiCosineSchedule(WarmupCosineSchedule):
warn_t_total = True
class WarmupCosineWithHardRestartsSchedule(WarmupCosineSchedule):
"""
Cosine learning rate schedule with linear warmup and hard restarts (if cycles > 1).
"""
def __init__(self, warmup=0.002, t_total=-1, cycles=1., **kw):
super(WarmupMultiCosineSchedule, self).__init__(warmup=warmup, t_total=t_total, cycles=cycles, **kw)
super(WarmupCosineWithHardRestartsSchedule, self).__init__(warmup=warmup, t_total=t_total, cycles=cycles, **kw)
assert(cycles >= 1.)
def get_lr_(self, progress):
@@ -90,7 +116,16 @@ class WarmupMultiCosineSchedule(WarmupCosineSchedule):
return ret
class WarmupCosineWithRestartsSchedule(WarmupMultiCosineSchedule):
class WarmupCosineWithWarmupRestartsSchedule(WarmupCosineWithHardRestartsSchedule):
"""
Cosine learning rate schedule with linear warmups and linear warmup restarts.
The same warmup rate is used for warmup restarts as for initial warmup.
The total effective fraction of warmup steps over all cycles is warmup * cycles!
"""
def __init__(self, warmup=0.002, t_total=-1, cycles=1., **kw):
assert(warmup * cycles < 1.)
super(WarmupCosineWithWarmupRestartsSchedule, self).__init__(warmup=warmup*cycles, t_total=t_total, cycles=cycles, **kw)
def get_lr_(self, progress):
if self.t_total <= 0.:
return 1.
@@ -104,7 +139,9 @@ class WarmupCosineWithRestartsSchedule(WarmupMultiCosineSchedule):
class WarmupConstantSchedule(LRSchedule):
warn_t_total = False
"""
Applies linear warmup. After warmup, the learning rate multiplier is always 1.0.
"""
def get_lr_(self, progress):
if progress < self.warmup:
return progress / self.warmup
@@ -112,6 +149,9 @@ class WarmupConstantSchedule(LRSchedule):
class WarmupLinearSchedule(LRSchedule):
"""
Linear warmup. Linear decay after warmup.
"""
warn_t_total = True
def get_lr_(self, progress):
if progress < self.warmup:
@@ -145,8 +185,7 @@ class BertAdam(Optimizer):
max_grad_norm: Maximum norm for the gradients (-1 means no clipping). Default: 1.0
"""
def __init__(self, params, lr=required, warmup=-1, t_total=-1, schedule='warmup_linear',
b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01, init_weight_decay=0.,
max_grad_norm=1.0):
b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01, max_grad_norm=1.0, **kwargs):
if lr is not required and lr < 0.0:
raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
if not isinstance(schedule, LRSchedule) and schedule not in SCHEDULES:
@@ -163,9 +202,10 @@ class BertAdam(Optimizer):
schedule = schedule_type(warmup=warmup, t_total=t_total)
else:
if warmup != -1 or t_total != -1:
logger.warning("Non-default warmup and t_total are ineffective when LRSchedule object is provided.")
logger.warning("Non-default warmup and t_total are ineffective when LRSchedule object is provided. "
"Please specify custom warmup and t_total in LRSchedule object.")
defaults = dict(lr=lr, schedule=schedule,
b1=b1, b2=b2, e=e, weight_decay=weight_decay, init_weight_decay=init_weight_decay,
b1=b1, b2=b2, e=e, weight_decay=weight_decay,
max_grad_norm=max_grad_norm)
super(BertAdam, self).__init__(params, defaults)
@@ -176,10 +216,8 @@ class BertAdam(Optimizer):
state = self.state[p]
if len(state) == 0:
return [0]
lr_scheduled = group['lr']
lr_scheduled *= group['schedule'].get_lr(state['step'])
lr.append(lr_scheduled)
return lr
@@ -235,8 +273,6 @@ class BertAdam(Optimizer):
if group['weight_decay'] > 0.0:
update += group['weight_decay'] * p.data
# TODO: init weight decay
lr_scheduled = group['lr']
lr_scheduled *= group['schedule'].get_lr(state['step'])