added warning

This commit is contained in:
lukovnikov
2019-02-27 16:10:31 +01:00
parent 60a372387f
commit 9bc3773c84

View File

@@ -35,19 +35,28 @@ def warmup_constant(x, warmup=0.002):
return x/warmup
return 1.0
class Warmup_Linear_with_Warning(object):
def __init__(self, **kw):
super(Warmup_Linear_with_Warning, self).__init__()
self.warned_at_x = -1
def __call__(self, x, warmup=0.002):
if x > 1 and x > self.warned_at_x:
logger.warning("Training beyond specified 't_total' steps. Learning rate set to zero. Please set 't_total' of BertAdam correctly.")
self.warned_at_x = x
return warmup_linear(x, warmup=warmup)
def warmup_linear(x, warmup=0.002):
""" Specifies a triangular learning rate schedule where peak is reached at `warmup`*`t_total`-th (as provided to BertAdam) training step.
After `t_total`-th training step, learning rate is zero. """
if x < warmup:
return x/warmup
if x > 1:
logger.warning("Training beyond specified 't_total' steps. Learning rate set to zero. Please set 't_total' of BertAdam correctly.")
return max((x-1.)/(warmup-1.), 0)
SCHEDULES = {
'warmup_cosine':warmup_cosine,
'warmup_constant':warmup_constant,
'warmup_linear':warmup_linear,
'warmup_linear': Warmup_Linear_with_Warning(), #warmup_linear,
}