- removed __all__ in optimization
- removed unused plotting code - using ABC for LRSchedule - added some schedule object init tests
This commit is contained in:
@@ -20,15 +20,12 @@ from torch.optim import Optimizer
|
|||||||
from torch.optim.optimizer import required
|
from torch.optim.optimizer import required
|
||||||
from torch.nn.utils import clip_grad_norm_
|
from torch.nn.utils import clip_grad_norm_
|
||||||
import logging
|
import logging
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["LRSchedule", "WarmupLinearSchedule", "WarmupConstantSchedule", "WarmupCosineSchedule", "BertAdam",
|
class _LRSchedule(ABC):
|
||||||
"WarmupCosineWithHardRestartsSchedule", "WarmupCosineWithWarmupRestartsSchedule", "SCHEDULES"]
|
|
||||||
|
|
||||||
|
|
||||||
class LRSchedule(object):
|
|
||||||
""" Parent of all LRSchedules here. """
|
""" Parent of all LRSchedules here. """
|
||||||
warn_t_total = False # is set to True for schedules where progressing beyond t_total steps doesn't make sense
|
warn_t_total = False # is set to True for schedules where progressing beyond t_total steps doesn't make sense
|
||||||
def __init__(self, warmup=0.002, t_total=-1, **kw):
|
def __init__(self, warmup=0.002, t_total=-1, **kw):
|
||||||
@@ -37,7 +34,7 @@ class LRSchedule(object):
|
|||||||
:param t_total: how many training steps (updates) are planned
|
:param t_total: how many training steps (updates) are planned
|
||||||
:param kw:
|
:param kw:
|
||||||
"""
|
"""
|
||||||
super(LRSchedule, self).__init__(**kw)
|
super(_LRSchedule, self).__init__(**kw)
|
||||||
if t_total < 0:
|
if t_total < 0:
|
||||||
logger.warning("t_total value of {} results in schedule not being applied".format(t_total))
|
logger.warning("t_total value of {} results in schedule not being applied".format(t_total))
|
||||||
if not 0.0 <= warmup < 1.0 and not warmup == -1:
|
if not 0.0 <= warmup < 1.0 and not warmup == -1:
|
||||||
@@ -65,16 +62,21 @@ class LRSchedule(object):
|
|||||||
# end warning
|
# end warning
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
def get_lr_(self, progress):
|
def get_lr_(self, progress):
|
||||||
"""
|
"""
|
||||||
:param progress: value between 0 and 1 (unless going beyond t_total steps) specifying training progress
|
:param progress: value between 0 and 1 (unless going beyond t_total steps) specifying training progress
|
||||||
:return: learning rate multiplier for current update
|
:return: learning rate multiplier for current update
|
||||||
"""
|
"""
|
||||||
return 1.
|
return 1.
|
||||||
# raise NotImplemented("use subclass") -
|
|
||||||
|
|
||||||
|
|
||||||
class WarmupCosineSchedule(LRSchedule):
|
class ConstantLR(_LRSchedule):
|
||||||
|
def get_lr_(self, progress):
|
||||||
|
return 1.
|
||||||
|
|
||||||
|
|
||||||
|
class WarmupCosineSchedule(_LRSchedule):
|
||||||
"""
|
"""
|
||||||
Cosine learning rate schedule with linear warmup. Cosine after warmup is without restarts.
|
Cosine learning rate schedule with linear warmup. Cosine after warmup is without restarts.
|
||||||
"""
|
"""
|
||||||
@@ -135,7 +137,7 @@ class WarmupCosineWithWarmupRestartsSchedule(WarmupCosineWithHardRestartsSchedul
|
|||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
class WarmupConstantSchedule(LRSchedule):
|
class WarmupConstantSchedule(_LRSchedule):
|
||||||
"""
|
"""
|
||||||
Applies linear warmup. After warmup always returns 1..
|
Applies linear warmup. After warmup always returns 1..
|
||||||
"""
|
"""
|
||||||
@@ -145,7 +147,7 @@ class WarmupConstantSchedule(LRSchedule):
|
|||||||
return 1.
|
return 1.
|
||||||
|
|
||||||
|
|
||||||
class WarmupLinearSchedule(LRSchedule):
|
class WarmupLinearSchedule(_LRSchedule):
|
||||||
"""
|
"""
|
||||||
Linear warmup. Linear decay after warmup.
|
Linear warmup. Linear decay after warmup.
|
||||||
"""
|
"""
|
||||||
@@ -157,8 +159,8 @@ class WarmupLinearSchedule(LRSchedule):
|
|||||||
|
|
||||||
|
|
||||||
SCHEDULES = {
|
SCHEDULES = {
|
||||||
None: LRSchedule,
|
None: ConstantLR,
|
||||||
"none": LRSchedule,
|
"none": ConstantLR,
|
||||||
"warmup_cosine": WarmupCosineSchedule,
|
"warmup_cosine": WarmupCosineSchedule,
|
||||||
"warmup_constant": WarmupConstantSchedule,
|
"warmup_constant": WarmupConstantSchedule,
|
||||||
"warmup_linear": WarmupLinearSchedule
|
"warmup_linear": WarmupLinearSchedule
|
||||||
@@ -185,7 +187,7 @@ class BertAdam(Optimizer):
|
|||||||
b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01, max_grad_norm=1.0, **kwargs):
|
b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01, max_grad_norm=1.0, **kwargs):
|
||||||
if lr is not required and lr < 0.0:
|
if lr is not required and lr < 0.0:
|
||||||
raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
|
raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
|
||||||
if not isinstance(schedule, LRSchedule) and schedule not in SCHEDULES:
|
if not isinstance(schedule, _LRSchedule) and schedule not in SCHEDULES:
|
||||||
raise ValueError("Invalid schedule parameter: {}".format(schedule))
|
raise ValueError("Invalid schedule parameter: {}".format(schedule))
|
||||||
if not 0.0 <= b1 < 1.0:
|
if not 0.0 <= b1 < 1.0:
|
||||||
raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1))
|
raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1))
|
||||||
@@ -194,7 +196,7 @@ class BertAdam(Optimizer):
|
|||||||
if not e >= 0.0:
|
if not e >= 0.0:
|
||||||
raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e))
|
raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e))
|
||||||
# initialize schedule object
|
# initialize schedule object
|
||||||
if not isinstance(schedule, LRSchedule):
|
if not isinstance(schedule, _LRSchedule):
|
||||||
schedule_type = SCHEDULES[schedule]
|
schedule_type = SCHEDULES[schedule]
|
||||||
schedule = schedule_type(warmup=warmup, t_total=t_total)
|
schedule = schedule_type(warmup=warmup, t_total=t_total)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -20,7 +20,8 @@ from torch.optim import Optimizer
|
|||||||
from torch.optim.optimizer import required
|
from torch.optim.optimizer import required
|
||||||
from torch.nn.utils import clip_grad_norm_
|
from torch.nn.utils import clip_grad_norm_
|
||||||
import logging
|
import logging
|
||||||
from .optimization import *
|
from .optimization import SCHEDULES, _LRSchedule, WarmupCosineWithWarmupRestartsSchedule, \
|
||||||
|
WarmupCosineWithHardRestartsSchedule, WarmupCosineSchedule, WarmupLinearSchedule, WarmupConstantSchedule
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -33,7 +34,7 @@ class OpenAIAdam(Optimizer):
|
|||||||
vector_l2=False, max_grad_norm=-1, **kwargs):
|
vector_l2=False, max_grad_norm=-1, **kwargs):
|
||||||
if lr is not required and lr < 0.0:
|
if lr is not required and lr < 0.0:
|
||||||
raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
|
raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
|
||||||
if not isinstance(schedule, LRSchedule) and schedule not in SCHEDULES:
|
if not isinstance(schedule, _LRSchedule) and schedule not in SCHEDULES:
|
||||||
raise ValueError("Invalid schedule parameter: {}".format(schedule))
|
raise ValueError("Invalid schedule parameter: {}".format(schedule))
|
||||||
if not 0.0 <= b1 < 1.0:
|
if not 0.0 <= b1 < 1.0:
|
||||||
raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1))
|
raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1))
|
||||||
@@ -42,7 +43,7 @@ class OpenAIAdam(Optimizer):
|
|||||||
if not e >= 0.0:
|
if not e >= 0.0:
|
||||||
raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e))
|
raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e))
|
||||||
# initialize schedule object
|
# initialize schedule object
|
||||||
if not isinstance(schedule, LRSchedule):
|
if not isinstance(schedule, _LRSchedule):
|
||||||
schedule_type = SCHEDULES[schedule]
|
schedule_type = SCHEDULES[schedule]
|
||||||
schedule = schedule_type(warmup=warmup, t_total=t_total)
|
schedule = schedule_type(warmup=warmup, t_total=t_total)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -21,10 +21,11 @@ import unittest
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from pytorch_pretrained_bert import BertAdam
|
from pytorch_pretrained_bert import BertAdam
|
||||||
from pytorch_pretrained_bert.optimization import WarmupCosineWithWarmupRestartsSchedule
|
from pytorch_pretrained_bert import OpenAIAdam
|
||||||
#from matplotlib import pyplot as plt
|
from pytorch_pretrained_bert.optimization import ConstantLR, WarmupLinearSchedule, WarmupCosineWithWarmupRestartsSchedule
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
class OptimizationTest(unittest.TestCase):
|
class OptimizationTest(unittest.TestCase):
|
||||||
|
|
||||||
def assertListAlmostEqual(self, list1, list2, tol):
|
def assertListAlmostEqual(self, list1, list2, tol):
|
||||||
@@ -49,13 +50,33 @@ class OptimizationTest(unittest.TestCase):
|
|||||||
self.assertListAlmostEqual(w.tolist(), [0.4, 0.2, -0.5], tol=1e-2)
|
self.assertListAlmostEqual(w.tolist(), [0.4, 0.2, -0.5], tol=1e-2)
|
||||||
|
|
||||||
|
|
||||||
|
class ScheduleInitTest(unittest.TestCase):
|
||||||
|
def test_bert_sched_init(self):
|
||||||
|
m = torch.nn.Linear(50, 50)
|
||||||
|
optim = BertAdam(m.parameters(), lr=0.001, warmup=.1, t_total=1000, schedule=None)
|
||||||
|
self.assertTrue(isinstance(optim.param_groups[0]["schedule"], ConstantLR))
|
||||||
|
optim = BertAdam(m.parameters(), lr=0.001, warmup=.1, t_total=1000, schedule="none")
|
||||||
|
self.assertTrue(isinstance(optim.param_groups[0]["schedule"], ConstantLR))
|
||||||
|
optim = BertAdam(m.parameters(), lr=0.001, warmup=.01, t_total=1000)
|
||||||
|
self.assertTrue(isinstance(optim.param_groups[0]["schedule"], WarmupLinearSchedule))
|
||||||
|
# shouldn't fail
|
||||||
|
|
||||||
|
def test_openai_sched_init(self):
|
||||||
|
m = torch.nn.Linear(50, 50)
|
||||||
|
optim = OpenAIAdam(m.parameters(), lr=0.001, warmup=.1, t_total=1000, schedule=None)
|
||||||
|
self.assertTrue(isinstance(optim.param_groups[0]["schedule"], ConstantLR))
|
||||||
|
optim = OpenAIAdam(m.parameters(), lr=0.001, warmup=.1, t_total=1000, schedule="none")
|
||||||
|
self.assertTrue(isinstance(optim.param_groups[0]["schedule"], ConstantLR))
|
||||||
|
optim = OpenAIAdam(m.parameters(), lr=0.001, warmup=.01, t_total=1000)
|
||||||
|
self.assertTrue(isinstance(optim.param_groups[0]["schedule"], WarmupLinearSchedule))
|
||||||
|
# shouldn't fail
|
||||||
|
|
||||||
|
|
||||||
class WarmupCosineWithRestartsTest(unittest.TestCase):
|
class WarmupCosineWithRestartsTest(unittest.TestCase):
|
||||||
def test_it(self):
|
def test_it(self):
|
||||||
m = WarmupCosineWithWarmupRestartsSchedule(warmup=0.05, t_total=1000., cycles=5)
|
m = WarmupCosineWithWarmupRestartsSchedule(warmup=0.05, t_total=1000., cycles=5)
|
||||||
x = np.arange(0, 1000)
|
x = np.arange(0, 1000)
|
||||||
y = [m.get_lr(xe) for xe in x]
|
y = [m.get_lr(xe) for xe in x]
|
||||||
# plt.plot(y)
|
|
||||||
# plt.show(block=False)
|
|
||||||
y = np.asarray(y)
|
y = np.asarray(y)
|
||||||
expected_zeros = y[[0, 200, 400, 600, 800]]
|
expected_zeros = y[[0, 200, 400, 600, 800]]
|
||||||
print(expected_zeros)
|
print(expected_zeros)
|
||||||
|
|||||||
Reference in New Issue
Block a user