added warning
This commit is contained in:
@@ -35,19 +35,28 @@ def warmup_constant(x, warmup=0.002):
|
||||
return x/warmup
|
||||
return 1.0
|
||||
|
||||
class Warmup_Linear_with_Warning(object):
    """Callable wrapper around `warmup_linear` that warns when training runs past `t_total`.

    An instance is called exactly like `warmup_linear(x, warmup)` and returns
    whatever that function returns. In addition, it logs a warning whenever `x`
    is above 1 and larger than the last `x` it already warned about, so calling
    again with the same (or a smaller) `x` does not repeat the message.
    """

    def __init__(self, **kw):
        # Keyword arguments are accepted but not used.
        super(Warmup_Linear_with_Warning, self).__init__()
        # Last `x` for which the warning was logged; -1 means "not warned yet".
        self.warned_at_x = -1

    def __call__(self, x, warmup=0.002):
        """Return `warmup_linear(x, warmup=warmup)`, warning first if `x` exceeds 1."""
        if x > 1:
            # Only warn when x has moved past the value we last warned at.
            if x > self.warned_at_x:
                logger.warning("Training beyond specified 't_total' steps. Learning rate set to zero. Please set 't_total' of BertAdam correctly.")
                self.warned_at_x = x
        return warmup_linear(x, warmup=warmup)
|
||||
|
||||
def warmup_linear(x, warmup=0.002):
    """ Specifies a triangular learning rate schedule where peak is reached at `warmup`*`t_total`-th (as provided to BertAdam) training step.
        After `t_total`-th training step, learning rate is zero.

    Args:
        x: training progress expressed as a fraction of `t_total`; values
            above 1 mean training has gone past `t_total`.
        warmup: fraction of `t_total` at which the peak is reached.

    Returns:
        `x/warmup` while `x < warmup`; afterwards the linear decay
        `(x-1)/(warmup-1)`, clamped so it never drops below zero.
    """
    if x < warmup:
        return x/warmup
    if x > 1:
        # NOTE(review): Warmup_Linear_with_Warning logs this same message before
        # delegating here, so callers going through that wrapper may see it
        # twice - confirm whether both warnings are intended.
        logger.warning("Training beyond specified 't_total' steps. Learning rate set to zero. Please set 't_total' of BertAdam correctly.")
    if warmup == 1:
        # The decay phase has zero length, so (x-1.)/(warmup-1.) would divide
        # by zero; once warmup is over the rate is simply zero.
        return 0.0
    return max((x-1.)/(warmup-1.), 0)
|
||||
|
||||
# Maps a schedule name to the callable that implements that schedule.
SCHEDULES = {
    'warmup_cosine': warmup_cosine,
    'warmup_constant': warmup_constant,
    # A single Warmup_Linear_with_Warning instance is created here and shared by
    # every lookup of 'warmup_linear', so its `warned_at_x` state is shared too.
    # (A plain `warmup_linear` entry under the same key used to precede this
    # one; it was always overwritten by this entry, so it has been dropped.)
    'warmup_linear': Warmup_Linear_with_Warning(),
}
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user