schedule fix
This commit is contained in:
@@ -38,11 +38,12 @@ class LRSchedule(object):
|
|||||||
:param kw:
|
:param kw:
|
||||||
"""
|
"""
|
||||||
super(LRSchedule, self).__init__(**kw)
|
super(LRSchedule, self).__init__(**kw)
|
||||||
self.warmup, self.t_total = warmup, t_total
|
|
||||||
if t_total <= 0:
|
if t_total <= 0:
|
||||||
logger.warning("t_total value of {} results in schedule not being applied".format(t_total))
|
logger.warning("t_total value of {} results in schedule not being applied".format(t_total))
|
||||||
if not 0.0 <= warmup < 1.0 and not warmup == -1:
|
if not 0.0 <= warmup < 1.0 and not warmup == -1:
|
||||||
raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup))
|
raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup))
|
||||||
|
warmup = max(warmup, 0)
|
||||||
|
self.warmup, self.t_total = warmup, t_total
|
||||||
self.warned_for_t_total_at_progress = -1
|
self.warned_for_t_total_at_progress = -1
|
||||||
|
|
||||||
def get_lr(self, step, nowarn=False):
|
def get_lr(self, step, nowarn=False):
|
||||||
@@ -51,6 +52,8 @@ class LRSchedule(object):
|
|||||||
:param nowarn: set to True to suppress warning regarding training beyond specified 't_total' steps
|
:param nowarn: set to True to suppress warning regarding training beyond specified 't_total' steps
|
||||||
:return: learning rate multiplier for current update
|
:return: learning rate multiplier for current update
|
||||||
"""
|
"""
|
||||||
|
if self.t_total < 0:
|
||||||
|
return 1.
|
||||||
progress = step / self.t_total
|
progress = step / self.t_total
|
||||||
ret = self.get_lr_(progress)
|
ret = self.get_lr_(progress)
|
||||||
# warning for exceeding t_total (only active with warmup_linear
|
# warning for exceeding t_total (only active with warmup_linear
|
||||||
@@ -87,9 +90,6 @@ class WarmupCosineSchedule(LRSchedule):
|
|||||||
self.cycles = cycles
|
self.cycles = cycles
|
||||||
|
|
||||||
def get_lr_(self, progress):
|
def get_lr_(self, progress):
|
||||||
""" get learning rate multiplier """
|
|
||||||
if self.t_total <= 0:
|
|
||||||
return 1.
|
|
||||||
if progress < self.warmup:
|
if progress < self.warmup:
|
||||||
return progress / self.warmup
|
return progress / self.warmup
|
||||||
else:
|
else:
|
||||||
@@ -106,8 +106,6 @@ class WarmupCosineWithHardRestartsSchedule(WarmupCosineSchedule):
|
|||||||
assert(cycles >= 1.)
|
assert(cycles >= 1.)
|
||||||
|
|
||||||
def get_lr_(self, progress):
|
def get_lr_(self, progress):
|
||||||
if self.t_total <= 0:
|
|
||||||
return 1.
|
|
||||||
if progress < self.warmup:
|
if progress < self.warmup:
|
||||||
return progress / self.warmup
|
return progress / self.warmup
|
||||||
else:
|
else:
|
||||||
@@ -124,11 +122,10 @@ class WarmupCosineWithWarmupRestartsSchedule(WarmupCosineWithHardRestartsSchedul
|
|||||||
"""
|
"""
|
||||||
def __init__(self, warmup=0.002, t_total=-1, cycles=1., **kw):
|
def __init__(self, warmup=0.002, t_total=-1, cycles=1., **kw):
|
||||||
assert(warmup * cycles < 1.)
|
assert(warmup * cycles < 1.)
|
||||||
super(WarmupCosineWithWarmupRestartsSchedule, self).__init__(warmup=warmup*cycles, t_total=t_total, cycles=cycles, **kw)
|
warmup = warmup * cycles if warmup >= 0 else warmup
|
||||||
|
super(WarmupCosineWithWarmupRestartsSchedule, self).__init__(warmup=warmup, t_total=t_total, cycles=cycles, **kw)
|
||||||
|
|
||||||
def get_lr_(self, progress):
|
def get_lr_(self, progress):
|
||||||
if self.t_total <= 0.:
|
|
||||||
return 1.
|
|
||||||
progress = progress * self.cycles % 1.
|
progress = progress * self.cycles % 1.
|
||||||
if progress < self.warmup:
|
if progress < self.warmup:
|
||||||
return progress / self.warmup
|
return progress / self.warmup
|
||||||
@@ -174,7 +171,7 @@ class BertAdam(Optimizer):
|
|||||||
lr: learning rate
|
lr: learning rate
|
||||||
warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1
|
warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1
|
||||||
t_total: total number of training steps for the learning
|
t_total: total number of training steps for the learning
|
||||||
rate schedule, -1 means constant learning rate. Default: -1
|
rate schedule, -1 means constant learning rate of 1. (no warmup regardless of warmup setting). Default: -1
|
||||||
schedule: schedule to use for the warmup (see above).
|
schedule: schedule to use for the warmup (see above).
|
||||||
Can be 'warmup_linear', 'warmup_constant', 'warmup_cosine', or a LRSchedule object.
|
Can be 'warmup_linear', 'warmup_constant', 'warmup_cosine', or a LRSchedule object.
|
||||||
Default: 'warmup_linear'
|
Default: 'warmup_linear'
|
||||||
|
|||||||
@@ -51,9 +51,9 @@ class OptimizationTest(unittest.TestCase):
|
|||||||
|
|
||||||
class WarmupCosineWithRestartsTest(unittest.TestCase):
|
class WarmupCosineWithRestartsTest(unittest.TestCase):
|
||||||
def test_it(self):
|
def test_it(self):
|
||||||
m = WarmupCosineWithWarmupRestartsSchedule(warmup=0.05, t_total=1, cycles=5)
|
m = WarmupCosineWithWarmupRestartsSchedule(warmup=-1, t_total=500, cycles=5)
|
||||||
x = np.arange(0, 1000) / 1000
|
x = np.arange(0, 1000)
|
||||||
y = [m.get_lr_(xe) for xe in x]
|
y = [m.get_lr(xe) for xe in x]
|
||||||
plt.plot(y)
|
plt.plot(y)
|
||||||
plt.show(block=False)
|
plt.show(block=False)
|
||||||
y = np.asarray(y)
|
y = np.asarray(y)
|
||||||
|
|||||||
Reference in New Issue
Block a user