added warning
This commit is contained in:
@@ -82,8 +82,6 @@ class BertAdam(Optimizer):
|
|||||||
b1=b1, b2=b2, e=e, weight_decay=weight_decay,
|
b1=b1, b2=b2, e=e, weight_decay=weight_decay,
|
||||||
max_grad_norm=max_grad_norm)
|
max_grad_norm=max_grad_norm)
|
||||||
super(BertAdam, self).__init__(params, defaults)
|
super(BertAdam, self).__init__(params, defaults)
|
||||||
# warning for t_total exceeded
|
|
||||||
self._warned_for_t_total_at_progress = -1 if schedule == "warmup_linear" else float("inf") # warning is not active with other schedules (since it doesn't break them)
|
|
||||||
|
|
||||||
def get_lr(self):
|
def get_lr(self):
|
||||||
lr = []
|
lr = []
|
||||||
@@ -111,6 +109,8 @@ class BertAdam(Optimizer):
|
|||||||
if closure is not None:
|
if closure is not None:
|
||||||
loss = closure()
|
loss = closure()
|
||||||
|
|
||||||
|
warned_for_t_total = False
|
||||||
|
|
||||||
for group in self.param_groups:
|
for group in self.param_groups:
|
||||||
for p in group['params']:
|
for p in group['params']:
|
||||||
if p.grad is None:
|
if p.grad is None:
|
||||||
@@ -157,11 +157,11 @@ class BertAdam(Optimizer):
|
|||||||
progress = state['step']/group['t_total']
|
progress = state['step']/group['t_total']
|
||||||
lr_scheduled = group['lr'] * schedule_fct(progress, group['warmup'])
|
lr_scheduled = group['lr'] * schedule_fct(progress, group['warmup'])
|
||||||
# warning for exceeding t_total (only active with warmup_linear
|
# warning for exceeding t_total (only active with warmup_linear
|
||||||
if progress > 1. and progress > self._warned_for_t_total_at_progress:
|
if group['schedule'] == "warmup_linear" and progress > 1. and not warned_for_t_total:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Training beyond specified 't_total' steps. Learning rate set to {}. "
|
"Training beyond specified 't_total' steps. Learning rate set to {}. "
|
||||||
"Please set 't_total' of {} correctly.".format(lr_scheduled, self.__class__.__name__))
|
"Please set 't_total' of {} correctly.".format(lr_scheduled, self.__class__.__name__))
|
||||||
self._warned_for_t_total_at_progress = progress
|
warned_for_t_total = True
|
||||||
# end warning
|
# end warning
|
||||||
else:
|
else:
|
||||||
lr_scheduled = group['lr']
|
lr_scheduled = group['lr']
|
||||||
|
|||||||
@@ -71,8 +71,6 @@ class OpenAIAdam(Optimizer):
|
|||||||
b1=b1, b2=b2, e=e, weight_decay=weight_decay, vector_l2=vector_l2,
|
b1=b1, b2=b2, e=e, weight_decay=weight_decay, vector_l2=vector_l2,
|
||||||
max_grad_norm=max_grad_norm)
|
max_grad_norm=max_grad_norm)
|
||||||
super(OpenAIAdam, self).__init__(params, defaults)
|
super(OpenAIAdam, self).__init__(params, defaults)
|
||||||
# warning for t_total exceeded
|
|
||||||
self._warned_for_t_total_at_progress = -1 if schedule == "warmup_linear" else float("inf") # warning is not active with other schedules (since it doesn't break them)
|
|
||||||
|
|
||||||
def get_lr(self):
|
def get_lr(self):
|
||||||
lr = []
|
lr = []
|
||||||
@@ -100,6 +98,8 @@ class OpenAIAdam(Optimizer):
|
|||||||
if closure is not None:
|
if closure is not None:
|
||||||
loss = closure()
|
loss = closure()
|
||||||
|
|
||||||
|
warned_for_t_total = False
|
||||||
|
|
||||||
for group in self.param_groups:
|
for group in self.param_groups:
|
||||||
for p in group['params']:
|
for p in group['params']:
|
||||||
if p.grad is None:
|
if p.grad is None:
|
||||||
@@ -140,11 +140,11 @@ class OpenAIAdam(Optimizer):
|
|||||||
progress = state['step']/group['t_total']
|
progress = state['step']/group['t_total']
|
||||||
lr_scheduled = group['lr'] * schedule_fct(progress, group['warmup'])
|
lr_scheduled = group['lr'] * schedule_fct(progress, group['warmup'])
|
||||||
# warning for exceeding t_total (only active with warmup_linear
|
# warning for exceeding t_total (only active with warmup_linear
|
||||||
if progress > 1. and progress > self._warned_for_t_total_at_progress:
|
if group['schedule'] == "warmup_linear" and progress > 1. and not warned_for_t_total:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Training beyond specified 't_total' steps. Learning rate set to {}. "
|
"Training beyond specified 't_total' steps. Learning rate set to {}. "
|
||||||
"Please set 't_total' of {} correctly.".format(lr_scheduled, self.__class__.__name__))
|
"Please set 't_total' of {} correctly.".format(lr_scheduled, self.__class__.__name__))
|
||||||
self._warned_for_t_total_at_progress = progress
|
warned_for_t_total = True
|
||||||
# end warning
|
# end warning
|
||||||
else:
|
else:
|
||||||
lr_scheduled = group['lr']
|
lr_scheduled = group['lr']
|
||||||
|
|||||||
Reference in New Issue
Block a user