From d164867d90c7b352445aa7d4028a6ba156a70a77 Mon Sep 17 00:00:00 2001 From: lukovnikov Date: Wed, 3 Apr 2019 16:13:51 +0200 Subject: [PATCH] - updated docs for optimization --- tests/optimization_test.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/optimization_test.py b/tests/optimization_test.py index 3f9f8abbfe..8c28ad38ad 100644 --- a/tests/optimization_test.py +++ b/tests/optimization_test.py @@ -20,7 +20,8 @@ import unittest import torch -from pytorch_pretrained_bert import BertAdam, WarmupCosineWithRestartsSchedule +from pytorch_pretrained_bert import BertAdam +from pytorch_pretrained_bert.optimization import WarmupCosineWithWarmupRestartsSchedule from matplotlib import pyplot as plt import numpy as np @@ -50,7 +51,7 @@ class OptimizationTest(unittest.TestCase): class WarmupCosineWithRestartsTest(unittest.TestCase): def test_it(self): - m = WarmupCosineWithRestartsSchedule(warmup=0.2, t_total=1, cycles=3) + m = WarmupCosineWithWarmupRestartsSchedule(warmup=0.05, t_total=1, cycles=3) x = np.arange(0, 1000) / 1000 y = [m.get_lr_(xe) for xe in x] plt.plot(y)