class weights

This commit is contained in:
lukovnikov
2019-03-18 18:29:12 +01:00
parent b6c1cae67b
commit 262a9992d7
2 changed files with 32 additions and 4 deletions

View File

@@ -20,7 +20,9 @@ import unittest
import torch
from pytorch_pretrained_bert import BertAdam
from pytorch_pretrained_bert import BertAdam, WarmupCosineWithRestartsSchedule
from matplotlib import pyplot as plt
import numpy as np
class OptimizationTest(unittest.TestCase):
@@ -46,5 +48,16 @@ class OptimizationTest(unittest.TestCase):
self.assertListAlmostEqual(w.tolist(), [0.4, 0.2, -0.5], tol=1e-2)
class WarmupCosineWithRestartsTest(unittest.TestCase):
def test_it(self):
m = WarmupCosineWithRestartsSchedule(warmup=0.2, t_total=1, cycles=3)
x = np.arange(0, 1000) / 1000
y = [m.get_lr_(xe) for xe in x]
plt.plot(y)
plt.show()
if __name__ == "__main__":
unittest.main()