schedule fix
This commit is contained in:
@@ -42,8 +42,8 @@ class LRSchedule(object):
|
|||||||
logger.warning("t_total value of {} results in schedule not being applied".format(t_total))
|
logger.warning("t_total value of {} results in schedule not being applied".format(t_total))
|
||||||
if not 0.0 <= warmup < 1.0 and not warmup == -1:
|
if not 0.0 <= warmup < 1.0 and not warmup == -1:
|
||||||
raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup))
|
raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup))
|
||||||
warmup = max(warmup, 0)
|
warmup = max(warmup, 0.)
|
||||||
self.warmup, self.t_total = warmup, t_total
|
self.warmup, self.t_total = float(warmup), float(t_total)
|
||||||
self.warned_for_t_total_at_progress = -1
|
self.warned_for_t_total_at_progress = -1
|
||||||
|
|
||||||
def get_lr(self, step, nowarn=False):
|
def get_lr(self, step, nowarn=False):
|
||||||
@@ -153,7 +153,7 @@ class WarmupLinearSchedule(LRSchedule):
|
|||||||
def get_lr_(self, progress):
|
def get_lr_(self, progress):
|
||||||
if progress < self.warmup:
|
if progress < self.warmup:
|
||||||
return progress / self.warmup
|
return progress / self.warmup
|
||||||
return max((progress - 1.) / (self.warmup - 1.), 0)
|
return max((progress - 1.) / (self.warmup - 1.), 0.)
|
||||||
|
|
||||||
|
|
||||||
SCHEDULES = {
|
SCHEDULES = {
|
||||||
|
|||||||
@@ -51,7 +51,7 @@ class OptimizationTest(unittest.TestCase):
|
|||||||
|
|
||||||
class WarmupCosineWithRestartsTest(unittest.TestCase):
|
class WarmupCosineWithRestartsTest(unittest.TestCase):
|
||||||
def test_it(self):
|
def test_it(self):
|
||||||
m = WarmupCosineWithWarmupRestartsSchedule(warmup=0.05, t_total=1000, cycles=5)
|
m = WarmupCosineWithWarmupRestartsSchedule(warmup=0.05, t_total=1000., cycles=5)
|
||||||
x = np.arange(0, 1000)
|
x = np.arange(0, 1000)
|
||||||
y = [m.get_lr(xe) for xe in x]
|
y = [m.get_lr(xe) for xe in x]
|
||||||
# plt.plot(y)
|
# plt.plot(y)
|
||||||
|
|||||||
Reference in New Issue
Block a user