From 4d79e0d386ee35deb232c74a84922809a3317c8c Mon Sep 17 00:00:00 2001 From: lukovnikov Date: Wed, 27 Feb 2019 16:50:05 +0100 Subject: [PATCH] added warning --- pytorch_pretrained_bert/optimization.py | 8 ++++---- pytorch_pretrained_bert/optimization_openai.py | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/pytorch_pretrained_bert/optimization.py b/pytorch_pretrained_bert/optimization.py index a6248374aa..213780ff46 100644 --- a/pytorch_pretrained_bert/optimization.py +++ b/pytorch_pretrained_bert/optimization.py @@ -82,8 +82,6 @@ class BertAdam(Optimizer): b1=b1, b2=b2, e=e, weight_decay=weight_decay, max_grad_norm=max_grad_norm) super(BertAdam, self).__init__(params, defaults) - # warning for t_total exceeded - self._warned_for_t_total_at_progress = -1 if schedule == "warmup_linear" else float("inf") # warning is not active with other schedules (since it doesn't break them) def get_lr(self): lr = [] @@ -111,6 +109,8 @@ class BertAdam(Optimizer): if closure is not None: loss = closure() + warned_for_t_total = False + for group in self.param_groups: for p in group['params']: if p.grad is None: @@ -157,11 +157,11 @@ class BertAdam(Optimizer): progress = state['step']/group['t_total'] lr_scheduled = group['lr'] * schedule_fct(progress, group['warmup']) # warning for exceeding t_total (only active with warmup_linear - if progress > 1. and progress > self._warned_for_t_total_at_progress: + if group['schedule'] == "warmup_linear" and progress > 1. and not warned_for_t_total: logger.warning( "Training beyond specified 't_total' steps. Learning rate set to {}. " "Please set 't_total' of {} correctly.".format(lr_scheduled, self.__class__.__name__)) - self._warned_for_t_total_at_progress = progress + warned_for_t_total = True # end warning else: lr_scheduled = group['lr'] diff --git a/pytorch_pretrained_bert/optimization_openai.py b/pytorch_pretrained_bert/optimization_openai.py index 80d496adf0..57dfe0f3cc 100644 --- a/pytorch_pretrained_bert/optimization_openai.py +++ b/pytorch_pretrained_bert/optimization_openai.py @@ -71,8 +71,6 @@ class OpenAIAdam(Optimizer): b1=b1, b2=b2, e=e, weight_decay=weight_decay, vector_l2=vector_l2, max_grad_norm=max_grad_norm) super(OpenAIAdam, self).__init__(params, defaults) - # warning for t_total exceeded - self._warned_for_t_total_at_progress = -1 if schedule == "warmup_linear" else float("inf") # warning is not active with other schedules (since it doesn't break them) def get_lr(self): lr = [] @@ -100,6 +98,8 @@ class OpenAIAdam(Optimizer): if closure is not None: loss = closure() + warned_for_t_total = False + for group in self.param_groups: for p in group['params']: if p.grad is None: @@ -140,11 +140,11 @@ class OpenAIAdam(Optimizer): progress = state['step']/group['t_total'] lr_scheduled = group['lr'] * schedule_fct(progress, group['warmup']) # warning for exceeding t_total (only active with warmup_linear - if progress > 1. and progress > self._warned_for_t_total_at_progress: + if group['schedule'] == "warmup_linear" and progress > 1. and not warned_for_t_total: logger.warning( "Training beyond specified 't_total' steps. Learning rate set to {}. " "Please set 't_total' of {} correctly.".format(lr_scheduled, self.__class__.__name__)) - self._warned_for_t_total_at_progress = progress + warned_for_t_total = True # end warning else: lr_scheduled = group['lr']