From 262a9992d7ab348dfc35bda6c550fbbba8f5bc42 Mon Sep 17 00:00:00 2001
From: lukovnikov
Date: Mon, 18 Mar 2019 18:29:12 +0100
Subject: [PATCH] class weights

---
 pytorch_pretrained_bert/optimization.py | 21 ++++++++++++++++++---
 tests/optimization_test.py              | 15 ++++++++++++++-
 2 files changed, 32 insertions(+), 4 deletions(-)

diff --git a/pytorch_pretrained_bert/optimization.py b/pytorch_pretrained_bert/optimization.py
index eb24c3bd37..a39a18cea3 100644
--- a/pytorch_pretrained_bert/optimization.py
+++ b/pytorch_pretrained_bert/optimization.py
@@ -24,7 +24,8 @@ import logging
 
 logger = logging.getLogger(__name__)
 
-__all__ = ["LRSchedule", "WarmupLinearSchedule", "WarmupConstantSchedule", "WarmupCosineSchedule", "BertAdam", "WarmupCosineWithRestartsSchedule"]
+__all__ = ["LRSchedule", "WarmupLinearSchedule", "WarmupConstantSchedule", "WarmupCosineSchedule", "BertAdam",
+           "WarmupMultiCosineSchedule", "WarmupCosineWithRestartsSchedule"]
 
 
 class LRSchedule(object):
@@ -72,10 +73,11 @@ class WarmupCosineSchedule(LRSchedule):
             return 0.5 * (1. + math.cos(math.pi * self.cycles * 2 * progress))
 
 
-class WarmupCosineWithRestartsSchedule(WarmupCosineSchedule):
+class WarmupMultiCosineSchedule(WarmupCosineSchedule):
     warn_t_total = True
     def __init__(self, warmup=0.002, t_total=-1, cycles=1., **kw):
-        super(WarmupCosineWithRestartsSchedule, self).__init__(warmup=warmup, t_total=t_total, cycles=cycles, **kw)
+        super(WarmupMultiCosineSchedule, self).__init__(warmup=warmup, t_total=t_total, cycles=cycles, **kw)
+        assert(cycles >= 1.)
 
     def get_lr_(self, progress):
         if self.t_total <= 0:
@@ -88,6 +90,19 @@ class WarmupCosineWithRestartsSchedule(WarmupCosineSchedule):
             return ret
 
 
+class WarmupCosineWithRestartsSchedule(WarmupMultiCosineSchedule):
+    def get_lr_(self, progress):
+        if self.t_total <= 0.:
+            return 1.
+        progress = progress * self.cycles % 1.
+        if progress < self.warmup:
+            return progress / self.warmup
+        else:
+            progress = (progress - self.warmup) / (1 - self.warmup)  # progress after warmup
+            ret = 0.5 * (1. + math.cos(math.pi * progress))
+            return ret
+
+
 class WarmupConstantSchedule(LRSchedule):
     warn_t_total = False
     def get_lr_(self, progress):
diff --git a/tests/optimization_test.py b/tests/optimization_test.py
index 848b9d1cf5..3f9f8abbfe 100644
--- a/tests/optimization_test.py
+++ b/tests/optimization_test.py
@@ -20,7 +20,9 @@ import unittest
 
 import torch
 
-from pytorch_pretrained_bert import BertAdam
+from pytorch_pretrained_bert import BertAdam, WarmupCosineWithRestartsSchedule
+from matplotlib import pyplot as plt
+import numpy as np
 
 
 class OptimizationTest(unittest.TestCase):
@@ -46,5 +48,16 @@ class OptimizationTest(unittest.TestCase):
         self.assertListAlmostEqual(w.tolist(), [0.4, 0.2, -0.5], tol=1e-2)
 
 
+class WarmupCosineWithRestartsTest(unittest.TestCase):
+    def test_it(self):
+        m = WarmupCosineWithRestartsSchedule(warmup=0.2, t_total=1, cycles=3)
+        x = np.arange(0, 1000) / 1000
+        y = [m.get_lr_(xe) for xe in x]
+        plt.plot(y)
+        plt.show()
+
+
+
+
 if __name__ == "__main__":
     unittest.main()