class weights
This commit is contained in:
@@ -24,7 +24,8 @@ import logging
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["LRSchedule", "WarmupLinearSchedule", "WarmupConstantSchedule", "WarmupCosineSchedule", "BertAdam", "WarmupCosineWithRestartsSchedule"]
|
__all__ = ["LRSchedule", "WarmupLinearSchedule", "WarmupConstantSchedule", "WarmupCosineSchedule", "BertAdam",
|
||||||
|
"WarmupMultiCosineSchedule", "WarmupCosineWithRestartsSchedule"]
|
||||||
|
|
||||||
|
|
||||||
class LRSchedule(object):
|
class LRSchedule(object):
|
||||||
@@ -72,10 +73,11 @@ class WarmupCosineSchedule(LRSchedule):
|
|||||||
return 0.5 * (1. + math.cos(math.pi * self.cycles * 2 * progress))
|
return 0.5 * (1. + math.cos(math.pi * self.cycles * 2 * progress))
|
||||||
|
|
||||||
|
|
||||||
class WarmupCosineWithRestartsSchedule(WarmupCosineSchedule):
|
class WarmupMultiCosineSchedule(WarmupCosineSchedule):
|
||||||
warn_t_total = True
|
warn_t_total = True
|
||||||
def __init__(self, warmup=0.002, t_total=-1, cycles=1., **kw):
|
def __init__(self, warmup=0.002, t_total=-1, cycles=1., **kw):
|
||||||
super(WarmupCosineWithRestartsSchedule, self).__init__(warmup=warmup, t_total=t_total, cycles=cycles, **kw)
|
super(WarmupMultiCosineSchedule, self).__init__(warmup=warmup, t_total=t_total, cycles=cycles, **kw)
|
||||||
|
assert(cycles >= 1.)
|
||||||
|
|
||||||
def get_lr_(self, progress):
|
def get_lr_(self, progress):
|
||||||
if self.t_total <= 0:
|
if self.t_total <= 0:
|
||||||
@@ -88,6 +90,19 @@ class WarmupCosineWithRestartsSchedule(WarmupCosineSchedule):
|
|||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
class WarmupCosineWithRestartsSchedule(WarmupMultiCosineSchedule):
    """Cosine schedule that restarts, with a linear warmup at the start of every cycle.

    The overall training progress is split into ``self.cycles`` equal cycles.
    Within each cycle the multiplier rises linearly from 0 to 1 over the first
    ``self.warmup`` fraction of the cycle, then follows half a cosine period
    from 1 down towards 0 over the remainder of the cycle.

    NOTE(review): ``self.warmup``, ``self.t_total`` and ``self.cycles`` are
    presumably set by the parent constructors - confirm against the base classes.
    """

    def get_lr_(self, progress):
        """Return the learning-rate multiplier for the given training progress.

        :param progress: overall training progress; it is scaled by
            ``self.cycles`` and reduced modulo 1 to get the position inside
            the current cycle.
        :return: ``1.`` when ``self.t_total`` is not positive, otherwise the
            warmup / cosine multiplier for the current cycle position.
        """
        if self.t_total <= 0.:
            return 1.

        # Position inside the current cycle, in [0, 1).
        cycle_pos = progress * self.cycles % 1.

        if cycle_pos < self.warmup:
            # Linear ramp during the warmup part of the cycle.
            return cycle_pos / self.warmup

        # Fraction of the post-warmup part of the cycle already completed.
        after_warmup = (cycle_pos - self.warmup) / (1 - self.warmup)
        return 0.5 * (1. + math.cos(math.pi * after_warmup))
|
||||||
|
|
||||||
|
|
||||||
class WarmupConstantSchedule(LRSchedule):
|
class WarmupConstantSchedule(LRSchedule):
|
||||||
warn_t_total = False
|
warn_t_total = False
|
||||||
def get_lr_(self, progress):
|
def get_lr_(self, progress):
|
||||||
|
|||||||
@@ -20,7 +20,9 @@ import unittest
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from pytorch_pretrained_bert import BertAdam
|
from pytorch_pretrained_bert import BertAdam, WarmupCosineWithRestartsSchedule
|
||||||
|
from matplotlib import pyplot as plt
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
class OptimizationTest(unittest.TestCase):
|
class OptimizationTest(unittest.TestCase):
|
||||||
|
|
||||||
@@ -46,5 +48,16 @@ class OptimizationTest(unittest.TestCase):
|
|||||||
self.assertListAlmostEqual(w.tolist(), [0.4, 0.2, -0.5], tol=1e-2)
|
self.assertListAlmostEqual(w.tolist(), [0.4, 0.2, -0.5], tol=1e-2)
|
||||||
|
|
||||||
|
|
||||||
|
class WarmupCosineWithRestartsTest(unittest.TestCase):
    """Checks the multiplier produced by WarmupCosineWithRestartsSchedule.

    The schedule is verified with assertions instead of being plotted, so the
    test can fail on a regression and does not need a display.
    """

    def test_it(self):
        """Verify bounds and per-cycle shape for warmup=0.2 and 3 cycles."""
        cycles = 3
        m = WarmupCosineWithRestartsSchedule(warmup=0.2, t_total=1, cycles=cycles)

        # The multiplier must stay inside [0, 1] over the whole progress range.
        for step in range(1000):
            lr = m.get_lr_(step / 1000.)
            self.assertGreaterEqual(lr, 0.)
            self.assertLessEqual(lr, 1.)

        # Training starts at the bottom of the first warmup ramp.
        self.assertAlmostEqual(m.get_lr_(0.), 0.)

        # Every cycle repeats the same shape: linear warmup, then cosine decay.
        for k in range(cycles):
            start = k / float(cycles)
            # Halfway through the warmup (cycle position 0.1 of 0.2).
            self.assertAlmostEqual(m.get_lr_(start + 0.1 / cycles), 0.5)
            # End of the warmup: the peak of the cycle.
            self.assertAlmostEqual(m.get_lr_(start + 0.2 / cycles), 1.)
            # Midpoint of the cosine decay (cycle position 0.6).
            self.assertAlmostEqual(m.get_lr_(start + 0.6 / cycles), 0.5)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
Reference in New Issue
Block a user