diff --git a/pytorch_pretrained_bert/optimization.py b/pytorch_pretrained_bert/optimization.py index c510109cb4..7d426f23ba 100644 --- a/pytorch_pretrained_bert/optimization.py +++ b/pytorch_pretrained_bert/optimization.py @@ -154,15 +154,15 @@ class BertAdam(Optimizer): if group['t_total'] != -1: schedule_fct = SCHEDULES[group['schedule']] - # warning for exceeding t_total (only active with warmup_linear progress = state['step']/group['t_total'] + lr_scheduled = group['lr'] * schedule_fct(progress, group['warmup']) + # warning for exceeding t_total (only active with warmup_linear) if progress > 1. and progress > self._warned_for_t_total_at_progress: logger.warning( - "Training beyond specified 't_total' steps. Learning rate set to zero. " - "Please set 't_total' of {} correctly.".format(self.__class__.__name__)) + "Training beyond specified 't_total' steps. Learning rate set to {}. " + "Please set 't_total' of {} correctly.".format(lr_scheduled, self.__class__.__name__)) self._warned_for_t_total_at_progress = progress # end warning - lr_scheduled = group['lr'] * schedule_fct(progress, group['warmup']) else: lr_scheduled = group['lr'] diff --git a/pytorch_pretrained_bert/optimization_openai.py b/pytorch_pretrained_bert/optimization_openai.py index e76e84ac05..c609d5c242 100644 --- a/pytorch_pretrained_bert/optimization_openai.py +++ b/pytorch_pretrained_bert/optimization_openai.py @@ -137,15 +137,15 @@ class OpenAIAdam(Optimizer): if group['t_total'] != -1: schedule_fct = SCHEDULES[group['schedule']] - # warning for exceeding t_total (only active with warmup_linear progress = state['step']/group['t_total'] + lr_scheduled = group['lr'] * schedule_fct(progress, group['warmup']) + # warning for exceeding t_total (only active with warmup_linear) if progress > 1. and progress > self._warned_for_t_total_at_progress: logger.warning( - "Training beyond specified 't_total' steps. Learning rate set to zero. 
" - "Please set 't_total' of {} correctly.".format(self.__class__.__name__)) + "Training beyond specified 't_total' steps. Learning rate set to {}. " + "Please set 't_total' of {} correctly.".format(lr_scheduled, self.__class__.__name__)) self._warned_for_t_total_at_progress = progress # end warning - lr_scheduled = group['lr'] * schedule_fct(progress, group['warmup']) else: lr_scheduled = group['lr']