class weights

This commit is contained in:
lukovnikov
2019-03-18 18:29:12 +01:00
parent b6c1cae67b
commit 262a9992d7
2 changed files with 32 additions and 4 deletions

View File

@@ -24,7 +24,8 @@ import logging
logger = logging.getLogger(__name__)
__all__ = ["LRSchedule", "WarmupLinearSchedule", "WarmupConstantSchedule", "WarmupCosineSchedule", "BertAdam", "WarmupCosineWithRestartsSchedule"]
__all__ = ["LRSchedule", "WarmupLinearSchedule", "WarmupConstantSchedule", "WarmupCosineSchedule", "BertAdam",
"WarmupMultiCosineSchedule", "WarmupCosineWithRestartsSchedule"]
class LRSchedule(object):
@@ -72,10 +73,11 @@ class WarmupCosineSchedule(LRSchedule):
return 0.5 * (1. + math.cos(math.pi * self.cycles * 2 * progress))
class WarmupCosineWithRestartsSchedule(WarmupCosineSchedule):
class WarmupMultiCosineSchedule(WarmupCosineSchedule):
warn_t_total = True
def __init__(self, warmup=0.002, t_total=-1, cycles=1., **kw):
super(WarmupCosineWithRestartsSchedule, self).__init__(warmup=warmup, t_total=t_total, cycles=cycles, **kw)
super(WarmupMultiCosineSchedule, self).__init__(warmup=warmup, t_total=t_total, cycles=cycles, **kw)
assert(cycles >= 1.)
def get_lr_(self, progress):
if self.t_total <= 0:
@@ -88,6 +90,19 @@ class WarmupCosineWithRestartsSchedule(WarmupCosineSchedule):
return ret
class WarmupCosineWithRestartsSchedule(WarmupMultiCosineSchedule):
def get_lr_(self, progress):
if self.t_total <= 0.:
return 1.
progress = progress * self.cycles % 1.
if progress < self.warmup:
return progress / self.warmup
else:
progress = (progress - self.warmup) / (1 - self.warmup) # progress after warmup
ret = 0.5 * (1. + math.cos(math.pi * progress))
return ret
class WarmupConstantSchedule(LRSchedule):
warn_t_total = False
def get_lr_(self, progress):