- updated docs for optimization
This commit is contained in:
@@ -20,7 +20,8 @@ import unittest
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from pytorch_pretrained_bert import BertAdam, WarmupCosineWithRestartsSchedule
|
from pytorch_pretrained_bert import BertAdam
|
||||||
|
from pytorch_pretrained_bert.optimization import WarmupCosineWithWarmupRestartsSchedule
|
||||||
from matplotlib import pyplot as plt
|
from matplotlib import pyplot as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -50,7 +51,7 @@ class OptimizationTest(unittest.TestCase):
|
|||||||
|
|
||||||
class WarmupCosineWithRestartsTest(unittest.TestCase):
|
class WarmupCosineWithRestartsTest(unittest.TestCase):
|
||||||
def test_it(self):
|
def test_it(self):
|
||||||
m = WarmupCosineWithRestartsSchedule(warmup=0.2, t_total=1, cycles=3)
|
m = WarmupCosineWithWarmupRestartsSchedule(warmup=0.05, t_total=1, cycles=3)
|
||||||
x = np.arange(0, 1000) / 1000
|
x = np.arange(0, 1000) / 1000
|
||||||
y = [m.get_lr_(xe) for xe in x]
|
y = [m.get_lr_(xe) for xe in x]
|
||||||
plt.plot(y)
|
plt.plot(y)
|
||||||
|
|||||||
Reference in New Issue
Block a user