restart cosine lr schedule
This commit is contained in:
@@ -69,7 +69,23 @@ class WarmupCosineSchedule(LRSchedule):
|
|||||||
return progress / self.warmup
|
return progress / self.warmup
|
||||||
else:
|
else:
|
||||||
progress = (progress - self.warmup) / (1 - self.warmup) # progress after warmup
|
progress = (progress - self.warmup) / (1 - self.warmup) # progress after warmup
|
||||||
return 0.5 * (1. + math.cos(math.pi * self.cycles * 2 * progress))
|
return 0.5 * (1. + math.cos(math.pi * ((self.cycles * 2 * progress) % 1))
|
||||||
|
|
||||||
|
|
||||||
|
class WarmupCosineWithRestartsSchedule(WarmupCosineSchedule):
|
||||||
|
warn_t_total = True
|
||||||
|
def __init__(self, warmup=0.002, t_total=-1, cycles=1., **kw):
|
||||||
|
super(WarmupCosineWithRestartsSchedule, self).__init__(warmup=warmup, t_total=t_total, cycles=cycles, **kw)
|
||||||
|
|
||||||
|
def get_lr_(self, progress):
|
||||||
|
if self.t_total <= 0:
|
||||||
|
return 1.
|
||||||
|
if progress < self.warmup:
|
||||||
|
return progress / self.warmup
|
||||||
|
else:
|
||||||
|
progress = (progress - self.warmup) / (1 - self.warmup) # progress after warmup
|
||||||
|
ret = 0.5 * (1. + math.cos(math.pi * self.cycles * 2 * progress))
|
||||||
|
return ret
|
||||||
|
|
||||||
|
|
||||||
class WarmupConstantSchedule(LRSchedule):
|
class WarmupConstantSchedule(LRSchedule):
|
||||||
|
|||||||
Reference in New Issue
Block a user