diff --git a/pytorch_pretrained_bert/optimization.py b/pytorch_pretrained_bert/optimization.py index c72c552a23..c510109cb4 100644 --- a/pytorch_pretrained_bert/optimization.py +++ b/pytorch_pretrained_bert/optimization.py @@ -35,17 +35,6 @@ def warmup_constant(x, warmup=0.002): return x/warmup return 1.0 -class Warmup_Linear_with_Warning(object): - def __init__(self, **kw): - super(Warmup_Linear_with_Warning, self).__init__() - self.warned_at_x = -1 - - def __call__(self, x, warmup=0.002): - if x > 1 and x > self.warned_at_x: - logger.warning("Training beyond specified 't_total' steps. Learning rate set to zero. Please set 't_total' of BertAdam correctly.") - self.warned_at_x = x - return warmup_linear(x, warmup=warmup) - def warmup_linear(x, warmup=0.002): """ Specifies a triangular learning rate schedule where peak is reached at `warmup`*`t_total`-th (as provided to BertAdam) training step. After `t_total`-th training step, learning rate is zero. """ @@ -54,9 +43,9 @@ def warmup_linear(x, warmup=0.002): return max((x-1.)/(warmup-1.), 0) SCHEDULES = { - 'warmup_cosine':warmup_cosine, - 'warmup_constant':warmup_constant, - 'warmup_linear': Warmup_Linear_with_Warning(), #warmup_linear, + 'warmup_cosine': warmup_cosine, + 'warmup_constant': warmup_constant, + 'warmup_linear': warmup_linear, } @@ -93,6 +82,8 @@ class BertAdam(Optimizer): b1=b1, b2=b2, e=e, weight_decay=weight_decay, max_grad_norm=max_grad_norm) super(BertAdam, self).__init__(params, defaults) + # warning for t_total exceeded + self._warned_for_t_total_at_progress = -1 if schedule == "warmup_linear" else float("inf") def get_lr(self): lr = [] @@ -163,7 +154,15 @@ class BertAdam(Optimizer): if group['t_total'] != -1: schedule_fct = SCHEDULES[group['schedule']] - lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup']) + # warning for exceeding t_total (only active with warmup_linear + progress = state['step']/group['t_total'] + if progress > 1. and progress > self._warned_for_t_total_at_progress: + logger.warning( + "Training beyond specified 't_total' steps. Learning rate set to zero. " + "Please set 't_total' of {} correctly.".format(self.__class__.__name__)) + self._warned_for_t_total_at_progress = progress + # end warning + lr_scheduled = group['lr'] * schedule_fct(progress, group['warmup']) else: lr_scheduled = group['lr'] diff --git a/pytorch_pretrained_bert/optimization_openai.py b/pytorch_pretrained_bert/optimization_openai.py index 7df6369023..e76e84ac05 100644 --- a/pytorch_pretrained_bert/optimization_openai.py +++ b/pytorch_pretrained_bert/optimization_openai.py @@ -40,8 +40,6 @@ def warmup_linear(x, warmup=0.002): After `t_total`-th training step, learning rate is zero. """ if x < warmup: return x/warmup - if x > 1: - logger.warning("Training beyond specified 't_total' steps. Learning rate set to zero. Please set 't_total' of BertAdam correctly.") return max((x-1.)/(warmup-1.), 0) SCHEDULES = { @@ -73,6 +71,8 @@ class OpenAIAdam(Optimizer): b1=b1, b2=b2, e=e, weight_decay=weight_decay, vector_l2=vector_l2, max_grad_norm=max_grad_norm) super(OpenAIAdam, self).__init__(params, defaults) + # warning for t_total exceeded + self._warned_for_t_total_at_progress = -1 if schedule == "warmup_linear" else float("inf") def get_lr(self): lr = [] @@ -137,7 +137,15 @@ class OpenAIAdam(Optimizer): if group['t_total'] != -1: schedule_fct = SCHEDULES[group['schedule']] - lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup']) + # warning for exceeding t_total (only active with warmup_linear + progress = state['step']/group['t_total'] + if progress > 1. and progress > self._warned_for_t_total_at_progress: + logger.warning( + "Training beyond specified 't_total' steps. Learning rate set to zero. " + "Please set 't_total' of {} correctly.".format(self.__class__.__name__)) + self._warned_for_t_total_at_progress = progress + # end warning + lr_scheduled = group['lr'] * schedule_fct(progress, group['warmup']) else: lr_scheduled = group['lr']