From 9bc3773c84846ebcecdc237f73158e14dae845c4 Mon Sep 17 00:00:00 2001 From: lukovnikov Date: Wed, 27 Feb 2019 16:10:31 +0100 Subject: [PATCH] added warning --- pytorch_pretrained_bert/optimization.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/pytorch_pretrained_bert/optimization.py b/pytorch_pretrained_bert/optimization.py index 6681410feb..c72c552a23 100644 --- a/pytorch_pretrained_bert/optimization.py +++ b/pytorch_pretrained_bert/optimization.py @@ -35,19 +35,28 @@ def warmup_constant(x, warmup=0.002): return x/warmup return 1.0 +class Warmup_Linear_with_Warning(object): + def __init__(self, **kw): + super(Warmup_Linear_with_Warning, self).__init__() + self.warned_at_x = -1 + + def __call__(self, x, warmup=0.002): + if x > 1 and x > self.warned_at_x: + logger.warning("Training beyond specified 't_total' steps. Learning rate set to zero. Please set 't_total' of BertAdam correctly.") + self.warned_at_x = x + return warmup_linear(x, warmup=warmup) + def warmup_linear(x, warmup=0.002): """ Specifies a triangular learning rate schedule where peak is reached at `warmup`*`t_total`-th (as provided to BertAdam) training step. After `t_total`-th training step, learning rate is zero. """ if x < warmup: return x/warmup - if x > 1: - logger.warning("Training beyond specified 't_total' steps. Learning rate set to zero. Please set 't_total' of BertAdam correctly.") return max((x-1.)/(warmup-1.), 0) SCHEDULES = { 'warmup_cosine':warmup_cosine, 'warmup_constant':warmup_constant, - 'warmup_linear':warmup_linear, + 'warmup_linear': Warmup_Linear_with_Warning(), #warmup_linear, }