added warning
This commit is contained in:
@@ -35,17 +35,6 @@ def warmup_constant(x, warmup=0.002):
|
|||||||
return x/warmup
|
return x/warmup
|
||||||
return 1.0
|
return 1.0
|
||||||
|
|
||||||
class Warmup_Linear_with_Warning(object):
|
|
||||||
def __init__(self, **kw):
|
|
||||||
super(Warmup_Linear_with_Warning, self).__init__()
|
|
||||||
self.warned_at_x = -1
|
|
||||||
|
|
||||||
def __call__(self, x, warmup=0.002):
|
|
||||||
if x > 1 and x > self.warned_at_x:
|
|
||||||
logger.warning("Training beyond specified 't_total' steps. Learning rate set to zero. Please set 't_total' of BertAdam correctly.")
|
|
||||||
self.warned_at_x = x
|
|
||||||
return warmup_linear(x, warmup=warmup)
|
|
||||||
|
|
||||||
def warmup_linear(x, warmup=0.002):
|
def warmup_linear(x, warmup=0.002):
|
||||||
""" Specifies a triangular learning rate schedule where peak is reached at `warmup`*`t_total`-th (as provided to BertAdam) training step.
|
""" Specifies a triangular learning rate schedule where peak is reached at `warmup`*`t_total`-th (as provided to BertAdam) training step.
|
||||||
After `t_total`-th training step, learning rate is zero. """
|
After `t_total`-th training step, learning rate is zero. """
|
||||||
@@ -54,9 +43,9 @@ def warmup_linear(x, warmup=0.002):
|
|||||||
return max((x-1.)/(warmup-1.), 0)
|
return max((x-1.)/(warmup-1.), 0)
|
||||||
|
|
||||||
SCHEDULES = {
|
SCHEDULES = {
|
||||||
'warmup_cosine':warmup_cosine,
|
'warmup_cosine': warmup_cosine,
|
||||||
'warmup_constant':warmup_constant,
|
'warmup_constant': warmup_constant,
|
||||||
'warmup_linear': Warmup_Linear_with_Warning(), #warmup_linear,
|
'warmup_linear': warmup_linear,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -93,6 +82,8 @@ class BertAdam(Optimizer):
|
|||||||
b1=b1, b2=b2, e=e, weight_decay=weight_decay,
|
b1=b1, b2=b2, e=e, weight_decay=weight_decay,
|
||||||
max_grad_norm=max_grad_norm)
|
max_grad_norm=max_grad_norm)
|
||||||
super(BertAdam, self).__init__(params, defaults)
|
super(BertAdam, self).__init__(params, defaults)
|
||||||
|
# warning for t_total exceeded
|
||||||
|
self._warned_for_t_total_at_progress = -1 if schedule == "warmup_linear" else float("inf")
|
||||||
|
|
||||||
def get_lr(self):
|
def get_lr(self):
|
||||||
lr = []
|
lr = []
|
||||||
@@ -163,7 +154,15 @@ class BertAdam(Optimizer):
|
|||||||
|
|
||||||
if group['t_total'] != -1:
|
if group['t_total'] != -1:
|
||||||
schedule_fct = SCHEDULES[group['schedule']]
|
schedule_fct = SCHEDULES[group['schedule']]
|
||||||
lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup'])
|
# warning for exceeding t_total (only active with warmup_linear
|
||||||
|
progress = state['step']/group['t_total']
|
||||||
|
if progress > 1. and progress > self._warned_for_t_total_at_progress:
|
||||||
|
logger.warning(
|
||||||
|
"Training beyond specified 't_total' steps. Learning rate set to zero. "
|
||||||
|
"Please set 't_total' of {} correctly.".format(self.__class__.__name__))
|
||||||
|
self._warned_for_t_total_at_progress = progress
|
||||||
|
# end warning
|
||||||
|
lr_scheduled = group['lr'] * schedule_fct(progress, group['warmup'])
|
||||||
else:
|
else:
|
||||||
lr_scheduled = group['lr']
|
lr_scheduled = group['lr']
|
||||||
|
|
||||||
|
|||||||
@@ -40,8 +40,6 @@ def warmup_linear(x, warmup=0.002):
|
|||||||
After `t_total`-th training step, learning rate is zero. """
|
After `t_total`-th training step, learning rate is zero. """
|
||||||
if x < warmup:
|
if x < warmup:
|
||||||
return x/warmup
|
return x/warmup
|
||||||
if x > 1:
|
|
||||||
logger.warning("Training beyond specified 't_total' steps. Learning rate set to zero. Please set 't_total' of BertAdam correctly.")
|
|
||||||
return max((x-1.)/(warmup-1.), 0)
|
return max((x-1.)/(warmup-1.), 0)
|
||||||
|
|
||||||
SCHEDULES = {
|
SCHEDULES = {
|
||||||
@@ -73,6 +71,8 @@ class OpenAIAdam(Optimizer):
|
|||||||
b1=b1, b2=b2, e=e, weight_decay=weight_decay, vector_l2=vector_l2,
|
b1=b1, b2=b2, e=e, weight_decay=weight_decay, vector_l2=vector_l2,
|
||||||
max_grad_norm=max_grad_norm)
|
max_grad_norm=max_grad_norm)
|
||||||
super(OpenAIAdam, self).__init__(params, defaults)
|
super(OpenAIAdam, self).__init__(params, defaults)
|
||||||
|
# warning for t_total exceeded
|
||||||
|
self._warned_for_t_total_at_progress = -1 if schedule == "warmup_linear" else float("inf")
|
||||||
|
|
||||||
def get_lr(self):
|
def get_lr(self):
|
||||||
lr = []
|
lr = []
|
||||||
@@ -137,7 +137,15 @@ class OpenAIAdam(Optimizer):
|
|||||||
|
|
||||||
if group['t_total'] != -1:
|
if group['t_total'] != -1:
|
||||||
schedule_fct = SCHEDULES[group['schedule']]
|
schedule_fct = SCHEDULES[group['schedule']]
|
||||||
lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup'])
|
# warning for exceeding t_total (only active with warmup_linear
|
||||||
|
progress = state['step']/group['t_total']
|
||||||
|
if progress > 1. and progress > self._warned_for_t_total_at_progress:
|
||||||
|
logger.warning(
|
||||||
|
"Training beyond specified 't_total' steps. Learning rate set to zero. "
|
||||||
|
"Please set 't_total' of {} correctly.".format(self.__class__.__name__))
|
||||||
|
self._warned_for_t_total_at_progress = progress
|
||||||
|
# end warning
|
||||||
|
lr_scheduled = group['lr'] * schedule_fct(progress, group['warmup'])
|
||||||
else:
|
else:
|
||||||
lr_scheduled = group['lr']
|
lr_scheduled = group['lr']
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user