added warning
This commit is contained in:
@@ -154,15 +154,15 @@ class BertAdam(Optimizer):
|
|||||||
|
|
||||||
if group['t_total'] != -1:
|
if group['t_total'] != -1:
|
||||||
schedule_fct = SCHEDULES[group['schedule']]
|
schedule_fct = SCHEDULES[group['schedule']]
|
||||||
# warning for exceeding t_total (only active with warmup_linear)
|
|
||||||
progress = state['step']/group['t_total']
|
progress = state['step']/group['t_total']
|
||||||
|
lr_scheduled = group['lr'] * schedule_fct(progress, group['warmup'])
|
||||||
|
# warning for exceeding t_total (only active with warmup_linear)
|
||||||
if progress > 1. and progress > self._warned_for_t_total_at_progress:
|
if progress > 1. and progress > self._warned_for_t_total_at_progress:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Training beyond specified 't_total' steps. Learning rate set to zero. "
|
"Training beyond specified 't_total' steps. Learning rate set to {}. "
|
||||||
"Please set 't_total' of {} correctly.".format(self.__class__.__name__))
|
"Please set 't_total' of {} correctly.".format(lr_scheduled, self.__class__.__name__))
|
||||||
self._warned_for_t_total_at_progress = progress
|
self._warned_for_t_total_at_progress = progress
|
||||||
# end warning
|
# end warning
|
||||||
lr_scheduled = group['lr'] * schedule_fct(progress, group['warmup'])
|
|
||||||
else:
|
else:
|
||||||
lr_scheduled = group['lr']
|
lr_scheduled = group['lr']
|
||||||
|
|
||||||
|
|||||||
@@ -137,15 +137,15 @@ class OpenAIAdam(Optimizer):
|
|||||||
|
|
||||||
if group['t_total'] != -1:
|
if group['t_total'] != -1:
|
||||||
schedule_fct = SCHEDULES[group['schedule']]
|
schedule_fct = SCHEDULES[group['schedule']]
|
||||||
# warning for exceeding t_total (only active with warmup_linear)
|
|
||||||
progress = state['step']/group['t_total']
|
progress = state['step']/group['t_total']
|
||||||
|
lr_scheduled = group['lr'] * schedule_fct(progress, group['warmup'])
|
||||||
|
# warning for exceeding t_total (only active with warmup_linear)
|
||||||
if progress > 1. and progress > self._warned_for_t_total_at_progress:
|
if progress > 1. and progress > self._warned_for_t_total_at_progress:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Training beyond specified 't_total' steps. Learning rate set to zero. "
|
"Training beyond specified 't_total' steps. Learning rate set to {}. "
|
||||||
"Please set 't_total' of {} correctly.".format(self.__class__.__name__))
|
"Please set 't_total' of {} correctly.".format(lr_scheduled, self.__class__.__name__))
|
||||||
self._warned_for_t_total_at_progress = progress
|
self._warned_for_t_total_at_progress = progress
|
||||||
# end warning
|
# end warning
|
||||||
lr_scheduled = group['lr'] * schedule_fct(progress, group['warmup'])
|
|
||||||
else:
|
else:
|
||||||
lr_scheduled = group['lr']
|
lr_scheduled = group['lr']
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user