diff --git a/pytorch_pretrained_bert/optimization.py b/pytorch_pretrained_bert/optimization.py index 565d3bff45..8c8dc3b862 100644 --- a/pytorch_pretrained_bert/optimization.py +++ b/pytorch_pretrained_bert/optimization.py @@ -38,11 +38,12 @@ class LRSchedule(object): :param kw: """ super(LRSchedule, self).__init__(**kw) - self.warmup, self.t_total = warmup, t_total if t_total <= 0: logger.warning("t_total value of {} results in schedule not being applied".format(t_total)) if not 0.0 <= warmup < 1.0 and not warmup == -1: raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup)) + warmup = max(warmup, 0) + self.warmup, self.t_total = warmup, t_total self.warned_for_t_total_at_progress = -1 def get_lr(self, step, nowarn=False): @@ -51,6 +52,8 @@ class LRSchedule(object): :param nowarn: set to True to suppress warning regarding training beyond specified 't_total' steps :return: learning rate multiplier for current update """ + if self.t_total < 0: + return 1. progress = step / self.t_total ret = self.get_lr_(progress) # warning for exceeding t_total (only active with warmup_linear @@ -87,9 +90,6 @@ class WarmupCosineSchedule(LRSchedule): self.cycles = cycles def get_lr_(self, progress): - """ get learning rate multiplier """ - if self.t_total <= 0: - return 1. if progress < self.warmup: return progress / self.warmup else: @@ -106,8 +106,6 @@ class WarmupCosineWithHardRestartsSchedule(WarmupCosineSchedule): assert(cycles >= 1.) def get_lr_(self, progress): - if self.t_total <= 0: - return 1. if progress < self.warmup: return progress / self.warmup else: @@ -124,11 +122,10 @@ class WarmupCosineWithWarmupRestartsSchedule(WarmupCosineWithHardRestartsSchedul """ def __init__(self, warmup=0.002, t_total=-1, cycles=1., **kw): assert(warmup * cycles < 1.) - super(WarmupCosineWithWarmupRestartsSchedule, self).__init__(warmup=warmup*cycles, t_total=t_total, cycles=cycles, **kw) + warmup = warmup * cycles if warmup >= 0 else warmup + super(WarmupCosineWithWarmupRestartsSchedule, self).__init__(warmup=warmup, t_total=t_total, cycles=cycles, **kw) def get_lr_(self, progress): - if self.t_total <= 0.: - return 1. progress = progress * self.cycles % 1. if progress < self.warmup: return progress / self.warmup @@ -174,7 +171,7 @@ class BertAdam(Optimizer): lr: learning rate warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1 t_total: total number of training steps for the learning - rate schedule, -1 means constant learning rate. Default: -1 + rate schedule, -1 means constant learning rate of 1. (no warmup regardless of warmup setting). Default: -1 schedule: schedule to use for the warmup (see above). Can be 'warmup_linear', 'warmup_constant', 'warmup_cosine', or a LRSchedule object. Default: 'warmup_linear' diff --git a/tests/optimization_test.py b/tests/optimization_test.py index 218da7581f..0eaae16d31 100644 --- a/tests/optimization_test.py +++ b/tests/optimization_test.py @@ -51,9 +51,9 @@ class OptimizationTest(unittest.TestCase): class WarmupCosineWithRestartsTest(unittest.TestCase): def test_it(self): - m = WarmupCosineWithWarmupRestartsSchedule(warmup=0.05, t_total=1, cycles=5) - x = np.arange(0, 1000) / 1000 - y = [m.get_lr_(xe) for xe in x] + m = WarmupCosineWithWarmupRestartsSchedule(warmup=-1, t_total=500, cycles=5) + x = np.arange(0, 1000) + y = [m.get_lr(xe) for xe in x] plt.plot(y) plt.show(block=False) y = np.asarray(y)