From baf66d141958785feb0dfc90d6cd8558eb95a774 Mon Sep 17 00:00:00 2001 From: lukovnikov Date: Tue, 12 Mar 2019 13:22:23 +0100 Subject: [PATCH] restart cosine lr schedule --- pytorch_pretrained_bert/optimization.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/pytorch_pretrained_bert/optimization.py b/pytorch_pretrained_bert/optimization.py index a92adb4c56..58e16f01a6 100644 --- a/pytorch_pretrained_bert/optimization.py +++ b/pytorch_pretrained_bert/optimization.py @@ -69,7 +69,23 @@ class WarmupCosineSchedule(LRSchedule): return progress / self.warmup else: progress = (progress - self.warmup) / (1 - self.warmup) # progress after warmup - return 0.5 * (1. + math.cos(math.pi * self.cycles * 2 * progress)) + return 0.5 * (1. + math.cos(math.pi * ((self.cycles * 2 * progress) % 1)) + + +class WarmupCosineWithRestartsSchedule(WarmupCosineSchedule): + warn_t_total = True + def __init__(self, warmup=0.002, t_total=-1, cycles=1., **kw): + super(WarmupCosineWithRestartsSchedule, self).__init__(warmup=warmup, t_total=t_total, cycles=cycles, **kw) + + def get_lr_(self, progress): + if self.t_total <= 0: + return 1. + if progress < self.warmup: + return progress / self.warmup + else: + progress = (progress - self.warmup) / (1 - self.warmup) # progress after warmup + ret = 0.5 * (1. + math.cos(math.pi * self.cycles * 2 * progress)) + return ret class WarmupConstantSchedule(LRSchedule):