schedule fix
This commit is contained in:
@@ -42,8 +42,8 @@ class LRSchedule(object):
|
||||
logger.warning("t_total value of {} results in schedule not being applied".format(t_total))
|
||||
if not 0.0 <= warmup < 1.0 and not warmup == -1:
|
||||
raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup))
|
||||
warmup = max(warmup, 0)
|
||||
self.warmup, self.t_total = warmup, t_total
|
||||
warmup = max(warmup, 0.)
|
||||
self.warmup, self.t_total = float(warmup), float(t_total)
|
||||
self.warned_for_t_total_at_progress = -1
|
||||
|
||||
def get_lr(self, step, nowarn=False):
|
||||
@@ -153,7 +153,7 @@ class WarmupLinearSchedule(LRSchedule):
|
||||
def get_lr_(self, progress):
|
||||
if progress < self.warmup:
|
||||
return progress / self.warmup
|
||||
return max((progress - 1.) / (self.warmup - 1.), 0)
|
||||
return max((progress - 1.) / (self.warmup - 1.), 0.)
|
||||
|
||||
|
||||
SCHEDULES = {
|
||||
|
||||
@@ -51,7 +51,7 @@ class OptimizationTest(unittest.TestCase):
|
||||
|
||||
class WarmupCosineWithRestartsTest(unittest.TestCase):
|
||||
def test_it(self):
|
||||
m = WarmupCosineWithWarmupRestartsSchedule(warmup=0.05, t_total=1000, cycles=5)
|
||||
m = WarmupCosineWithWarmupRestartsSchedule(warmup=0.05, t_total=1000., cycles=5)
|
||||
x = np.arange(0, 1000)
|
||||
y = [m.get_lr(xe) for xe in x]
|
||||
# plt.plot(y)
|
||||
|
||||
Reference in New Issue
Block a user