From 20686b78fc786bf662b4ed1bd743823aeef57fd8 Mon Sep 17 00:00:00 2001 From: lukovnikov Date: Wed, 3 Apr 2019 18:13:52 +0200 Subject: [PATCH] schedule fix --- pytorch_pretrained_bert/optimization.py | 6 +++--- tests/optimization_test.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pytorch_pretrained_bert/optimization.py b/pytorch_pretrained_bert/optimization.py index 92cf2b05eb..df5b50b51d 100644 --- a/pytorch_pretrained_bert/optimization.py +++ b/pytorch_pretrained_bert/optimization.py @@ -42,8 +42,8 @@ class LRSchedule(object): logger.warning("t_total value of {} results in schedule not being applied".format(t_total)) if not 0.0 <= warmup < 1.0 and not warmup == -1: raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup)) - warmup = max(warmup, 0) - self.warmup, self.t_total = warmup, t_total + warmup = max(warmup, 0.) + self.warmup, self.t_total = float(warmup), float(t_total) self.warned_for_t_total_at_progress = -1 def get_lr(self, step, nowarn=False): @@ -153,7 +153,7 @@ class WarmupLinearSchedule(LRSchedule): def get_lr_(self, progress): if progress < self.warmup: return progress / self.warmup - return max((progress - 1.) / (self.warmup - 1.), 0) + return max((progress - 1.) / (self.warmup - 1.), 0.) SCHEDULES = { diff --git a/tests/optimization_test.py b/tests/optimization_test.py index 80216cc8d4..e74f4bba6c 100644 --- a/tests/optimization_test.py +++ b/tests/optimization_test.py @@ -51,7 +51,7 @@ class OptimizationTest(unittest.TestCase): class WarmupCosineWithRestartsTest(unittest.TestCase): def test_it(self): - m = WarmupCosineWithWarmupRestartsSchedule(warmup=0.05, t_total=1000, cycles=5) + m = WarmupCosineWithWarmupRestartsSchedule(warmup=0.05, t_total=1000., cycles=5) x = np.arange(0, 1000) y = [m.get_lr(xe) for xe in x] # plt.plot(y)