added warning
This commit is contained in:
@@ -35,19 +35,28 @@ def warmup_constant(x, warmup=0.002):
|
|||||||
return x/warmup
|
return x/warmup
|
||||||
return 1.0
|
return 1.0
|
||||||
|
|
||||||
|
class Warmup_Linear_with_Warning(object):
    """Callable wrapper around `warmup_linear` that logs a warning once training passes `t_total`.

    The wrapper keeps the largest `x` it has already warned about, so calling it
    again with the same (or a smaller) `x` does not repeat the warning. A larger
    `x` beyond 1 triggers a new warning.
    """

    def __init__(self, **kw):
        # Keyword arguments are accepted for call-site compatibility but are not used.
        super(Warmup_Linear_with_Warning, self).__init__()
        # Largest `x` a warning has been emitted for; -1 means "never warned".
        self.warned_at_x = -1

    def __call__(self, x, warmup=0.002):
        """Return `warmup_linear(x, warmup=warmup)`, warning first if `x` exceeds 1.

        Args:
            x: training progress; values above 1 are treated as "beyond `t_total`"
                (presumably step / t_total -- confirm against the caller).
            warmup: forwarded unchanged to `warmup_linear`.

        Returns:
            Whatever `warmup_linear` returns for the same arguments.
        """
        if x > 1 and x > self.warned_at_x:
            logger.warning("Training beyond specified 't_total' steps. Learning rate set to zero. Please set 't_total' of BertAdam correctly.")
            # Remember this progress value so repeated calls at the same `x` stay quiet.
            self.warned_at_x = x
        return warmup_linear(x, warmup=warmup)
|
||||||
|
|
||||||
def warmup_linear(x, warmup=0.002):
    """ Specifies a triangular learning rate schedule where peak is reached at `warmup`*`t_total`-th (as provided to BertAdam) training step.
        After `t_total`-th training step, learning rate is zero.

        Args:
            x: training progress; the ramp-up applies while `x < warmup`, and the
                result is clamped to zero once `x` reaches 1.
            warmup: progress value at which the multiplier peaks at 1.0.
                Must not be exactly 1, otherwise the decay branch divides by zero.

        Returns:
            The multiplier: `x/warmup` during warmup, then a linear decay from 1
            down to 0 at `x == 1`, and 0 afterwards.
    """
    if x < warmup:
        # Linear ramp from 0 up to 1 over the warmup phase.
        return x/warmup
    # Linear decay; max() clamps the negative values obtained for x > 1 to zero.
    # This function does not log when x > 1.
    return max((x-1.)/(warmup-1.), 0)
|
||||||
|
|
||||||
# Registry of learning-rate schedules, keyed by schedule name.
SCHEDULES = {
    'warmup_cosine': warmup_cosine,
    'warmup_constant': warmup_constant,
    # A single shared instance wrapping `warmup_linear`; it carries the
    # `warned_at_x` state, so every user of this entry shares that state.
    'warmup_linear': Warmup_Linear_with_Warning(),
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user