class weights
This commit is contained in:
@@ -20,7 +20,9 @@ import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from pytorch_pretrained_bert import BertAdam
|
||||
from pytorch_pretrained_bert import BertAdam, WarmupCosineWithRestartsSchedule
|
||||
from matplotlib import pyplot as plt
|
||||
import numpy as np
|
||||
|
||||
class OptimizationTest(unittest.TestCase):
|
||||
|
||||
@@ -46,5 +48,16 @@ class OptimizationTest(unittest.TestCase):
|
||||
self.assertListAlmostEqual(w.tolist(), [0.4, 0.2, -0.5], tol=1e-2)
|
||||
|
||||
|
||||
class WarmupCosineWithRestartsTest(unittest.TestCase):
|
||||
def test_it(self):
|
||||
m = WarmupCosineWithRestartsSchedule(warmup=0.2, t_total=1, cycles=3)
|
||||
x = np.arange(0, 1000) / 1000
|
||||
y = [m.get_lr_(xe) for xe in x]
|
||||
plt.plot(y)
|
||||
plt.show()
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user