From bb7557d3ab96f139997bfaa70ff2b4a6c18994e0 Mon Sep 17 00:00:00 2001 From: lukovnikov Date: Sun, 21 Apr 2019 13:48:33 +0200 Subject: [PATCH] - removed __all__ in optimization - removed unused plotting code - using ABC for LRSchedule - added some schedule object init tests --- pytorch_pretrained_bert/optimization.py | 30 ++++++++++--------- .../optimization_openai.py | 7 +++-- tests/optimization_test.py | 29 +++++++++++++++--- 3 files changed, 45 insertions(+), 21 deletions(-) diff --git a/pytorch_pretrained_bert/optimization.py b/pytorch_pretrained_bert/optimization.py index ca973015a6..d2d4f7f5e5 100644 --- a/pytorch_pretrained_bert/optimization.py +++ b/pytorch_pretrained_bert/optimization.py @@ -20,15 +20,12 @@ from torch.optim import Optimizer from torch.optim.optimizer import required from torch.nn.utils import clip_grad_norm_ import logging +from abc import ABC, abstractmethod logger = logging.getLogger(__name__) -__all__ = ["LRSchedule", "WarmupLinearSchedule", "WarmupConstantSchedule", "WarmupCosineSchedule", "BertAdam", - "WarmupCosineWithHardRestartsSchedule", "WarmupCosineWithWarmupRestartsSchedule", "SCHEDULES"] - - -class LRSchedule(object): +class _LRSchedule(ABC): """ Parent of all LRSchedules here. """ warn_t_total = False # is set to True for schedules where progressing beyond t_total steps doesn't make sense def __init__(self, warmup=0.002, t_total=-1, **kw): @@ -37,7 +34,7 @@ class LRSchedule(object): :param t_total: how many training steps (updates) are planned :param kw: """ - super(LRSchedule, self).__init__(**kw) + super(_LRSchedule, self).__init__(**kw) if t_total < 0: logger.warning("t_total value of {} results in schedule not being applied".format(t_total)) if not 0.0 <= warmup < 1.0 and not warmup == -1: @@ -65,16 +62,21 @@ class LRSchedule(object): # end warning return ret + @abstractmethod def get_lr_(self, progress): """ :param progress: value between 0 and 1 (unless going beyond t_total steps) specifying training progress :return: learning rate multiplier for current update """ return 1. - # raise NotImplemented("use subclass") - -class WarmupCosineSchedule(LRSchedule): +class ConstantLR(_LRSchedule): + def get_lr_(self, progress): + return 1. + + +class WarmupCosineSchedule(_LRSchedule): """ Cosine learning rate schedule with linear warmup. Cosine after warmup is without restarts. """ @@ -135,7 +137,7 @@ class WarmupCosineWithWarmupRestartsSchedule(WarmupCosineWithHardRestartsSchedul return ret -class WarmupConstantSchedule(LRSchedule): +class WarmupConstantSchedule(_LRSchedule): """ Applies linear warmup. After warmup always returns 1.. """ @@ -145,7 +147,7 @@ class WarmupConstantSchedule(LRSchedule): return 1. -class WarmupLinearSchedule(LRSchedule): +class WarmupLinearSchedule(_LRSchedule): """ Linear warmup. Linear decay after warmup. """ @@ -157,8 +159,8 @@ class WarmupLinearSchedule(LRSchedule): SCHEDULES = { - None: LRSchedule, - "none": LRSchedule, + None: ConstantLR, + "none": ConstantLR, "warmup_cosine": WarmupCosineSchedule, "warmup_constant": WarmupConstantSchedule, "warmup_linear": WarmupLinearSchedule @@ -185,7 +187,7 @@ class BertAdam(Optimizer): b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01, max_grad_norm=1.0, **kwargs): if lr is not required and lr < 0.0: raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr)) - if not isinstance(schedule, LRSchedule) and schedule not in SCHEDULES: + if not isinstance(schedule, _LRSchedule) and schedule not in SCHEDULES: raise ValueError("Invalid schedule parameter: {}".format(schedule)) if not 0.0 <= b1 < 1.0: raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1)) @@ -194,7 +196,7 @@ class BertAdam(Optimizer): if not e >= 0.0: raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e)) # initialize schedule object - if not isinstance(schedule, LRSchedule): + if not isinstance(schedule, _LRSchedule): schedule_type = SCHEDULES[schedule] schedule = schedule_type(warmup=warmup, t_total=t_total) else: diff --git a/pytorch_pretrained_bert/optimization_openai.py b/pytorch_pretrained_bert/optimization_openai.py index 5bfea476a6..0cf0494e20 100644 --- a/pytorch_pretrained_bert/optimization_openai.py +++ b/pytorch_pretrained_bert/optimization_openai.py @@ -20,7 +20,8 @@ from torch.optim import Optimizer from torch.optim.optimizer import required from torch.nn.utils import clip_grad_norm_ import logging -from .optimization import * +from .optimization import SCHEDULES, _LRSchedule, WarmupCosineWithWarmupRestartsSchedule, \ + WarmupCosineWithHardRestartsSchedule, WarmupCosineSchedule, WarmupLinearSchedule, WarmupConstantSchedule logger = logging.getLogger(__name__) @@ -33,7 +34,7 @@ class OpenAIAdam(Optimizer): vector_l2=False, max_grad_norm=-1, **kwargs): if lr is not required and lr < 0.0: raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr)) - if not isinstance(schedule, LRSchedule) and schedule not in SCHEDULES: + if not isinstance(schedule, _LRSchedule) and schedule not in SCHEDULES: raise ValueError("Invalid schedule parameter: {}".format(schedule)) if not 0.0 <= b1 < 1.0: raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1)) @@ -42,7 +43,7 @@ class OpenAIAdam(Optimizer): if not e >= 0.0: raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e)) # initialize schedule object - if not isinstance(schedule, LRSchedule): + if not isinstance(schedule, _LRSchedule): schedule_type = SCHEDULES[schedule] schedule = schedule_type(warmup=warmup, t_total=t_total) else: diff --git a/tests/optimization_test.py b/tests/optimization_test.py index e74f4bba6c..f52aeb506b 100644 --- a/tests/optimization_test.py +++ b/tests/optimization_test.py @@ -21,10 +21,11 @@ import unittest import torch from pytorch_pretrained_bert import BertAdam -from pytorch_pretrained_bert.optimization import WarmupCosineWithWarmupRestartsSchedule -#from matplotlib import pyplot as plt +from pytorch_pretrained_bert import OpenAIAdam +from pytorch_pretrained_bert.optimization import ConstantLR, WarmupLinearSchedule, WarmupCosineWithWarmupRestartsSchedule import numpy as np + class OptimizationTest(unittest.TestCase): def assertListAlmostEqual(self, list1, list2, tol): @@ -49,13 +50,33 @@ class OptimizationTest(unittest.TestCase): self.assertListAlmostEqual(w.tolist(), [0.4, 0.2, -0.5], tol=1e-2) +class ScheduleInitTest(unittest.TestCase): + def test_bert_sched_init(self): + m = torch.nn.Linear(50, 50) + optim = BertAdam(m.parameters(), lr=0.001, warmup=.1, t_total=1000, schedule=None) + self.assertTrue(isinstance(optim.param_groups[0]["schedule"], ConstantLR)) + optim = BertAdam(m.parameters(), lr=0.001, warmup=.1, t_total=1000, schedule="none") + self.assertTrue(isinstance(optim.param_groups[0]["schedule"], ConstantLR)) + optim = BertAdam(m.parameters(), lr=0.001, warmup=.01, t_total=1000) + self.assertTrue(isinstance(optim.param_groups[0]["schedule"], WarmupLinearSchedule)) + # shouldn't fail + + def test_openai_sched_init(self): + m = torch.nn.Linear(50, 50) + optim = OpenAIAdam(m.parameters(), lr=0.001, warmup=.1, t_total=1000, schedule=None) + self.assertTrue(isinstance(optim.param_groups[0]["schedule"], ConstantLR)) + optim = OpenAIAdam(m.parameters(), lr=0.001, warmup=.1, t_total=1000, schedule="none") + self.assertTrue(isinstance(optim.param_groups[0]["schedule"], ConstantLR)) + optim = OpenAIAdam(m.parameters(), lr=0.001, warmup=.01, t_total=1000) + self.assertTrue(isinstance(optim.param_groups[0]["schedule"], WarmupLinearSchedule)) + # shouldn't fail + + class WarmupCosineWithRestartsTest(unittest.TestCase): def test_it(self): m = WarmupCosineWithWarmupRestartsSchedule(warmup=0.05, t_total=1000., cycles=5) x = np.arange(0, 1000) y = [m.get_lr(xe) for xe in x] - # plt.plot(y) - # plt.show(block=False) y = np.asarray(y) expected_zeros = y[[0, 200, 400, 600, 800]] print(expected_zeros)