optimization tests

This commit is contained in:
thomwolf
2019-07-11 17:39:47 +02:00
parent e4f9dca018
commit ccb6947dc1
4 changed files with 91 additions and 49 deletions

View File

@@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
class ConstantLRSchedule(LambdaLR):
def __init__(self, optimizer, last_epoch=-1):
super(ConstantLRSchedule, self).__init__(optimizer, lambda x: x, last_epoch=last_epoch)
super(ConstantLRSchedule, self).__init__(optimizer, lambda _: 1.0, last_epoch=last_epoch)
class WarmupCosineSchedule(LambdaLR):
"""
@@ -42,10 +42,10 @@ class WarmupCosineSchedule(LambdaLR):
def lr_lambda(step):
if step < warmup_steps:
return step / max(1, warmup_steps)
return float(step) / float(max(1.0, warmup_steps))
else:
progress = (step - warmup_steps) / max(1, t_total - warmup_steps) # progress after warmup
return 0.5 * (1. + math.cos(math.pi * cycles * 2 * progress))
progress = float(step - warmup_steps) / float(max(1, t_total - warmup_steps)) # progress after warmup
return 0.5 * (1. + math.cos(math.pi * float(cycles) * 2.0 * progress))
super(WarmupCosineSchedule, self).__init__(optimizer, lr_lambda, last_epoch=last_epoch)
@@ -59,11 +59,12 @@ class WarmupCosineWithHardRestartsSchedule(LambdaLR):
def lr_lambda(step):
if step < warmup_steps:
return step / max(1, warmup_steps)
return float(step) / float(max(1, warmup_steps))
else:
progress = (step - warmup_steps) / max(1, t_total - warmup_steps) # progress after warmup
ret = 0.5 * (1. + math.cos(math.pi * ((cycles * progress) % 1)))
return ret
progress = float(step - warmup_steps) / float(max(1, t_total - warmup_steps)) # progress after warmup
if progress >= 1.0:
return 0.0
return 0.5 * (1. + math.cos(math.pi * ((float(cycles) * progress) % 1.0)))
super(WarmupCosineWithHardRestartsSchedule, self).__init__(optimizer, lr_lambda, last_epoch=last_epoch)
@@ -77,7 +78,7 @@ class WarmupConstantSchedule(LambdaLR):
def lr_lambda(step):
if step < warmup_steps:
return step / warmup_steps
return float(step) / float(max(1.0, warmup_steps))
return 1.
super(WarmupConstantSchedule, self).__init__(optimizer, lr_lambda, last_epoch=last_epoch)
@@ -92,8 +93,8 @@ class WarmupLinearSchedule(LambdaLR):
def lr_lambda(step):
if step < warmup_steps:
return step / max(1, warmup_steps)
return (t_total - step) / max(1, t_total - warmup_steps)
return float(step) / float(max(1, warmup_steps))
return float(t_total - step) / float(max(1.0, t_total - warmup_steps))
super(WarmupLinearSchedule, self).__init__(optimizer, lr_lambda, last_epoch=last_epoch)

View File

@@ -26,6 +26,13 @@ from pytorch_transformers import (AdamW, ConstantLRSchedule, WarmupConstantSched
import numpy as np
def unwrap_schedule(scheduler, num_steps=10):
lrs = []
for _ in range(num_steps):
scheduler.step()
lrs.append(scheduler.get_lr())
return lrs
class OptimizationTest(unittest.TestCase):
def assertListAlmostEqual(self, list1, list2, tol):
@@ -38,9 +45,7 @@ class OptimizationTest(unittest.TestCase):
target = torch.tensor([0.4, 0.2, -0.5])
criterion = torch.nn.MSELoss()
# No warmup, constant schedule, no gradient clipping
optimizer = AdamW(params=[w], lr=2e-1,
weight_decay=0.0,
max_grad_norm=-1)
optimizer = AdamW(params=[w], lr=2e-1, weight_decay=0.0)
for _ in range(100):
loss = criterion(w, target)
loss.backward()
@@ -51,29 +56,49 @@ class OptimizationTest(unittest.TestCase):
class ScheduleInitTest(unittest.TestCase):
def test_sched_init(self):
m = torch.nn.Linear(50, 50)
optim = AdamW(m.parameters(), lr=0.001, warmup=.1, t_total=1000, schedule=None)
self.assertTrue(isinstance(optim.param_groups[0]["schedule"], ConstantLR))
optim = AdamW(m.parameters(), lr=0.001, warmup=.1, t_total=1000, schedule="none")
self.assertTrue(isinstance(optim.param_groups[0]["schedule"], ConstantLR))
optim = AdamW(m.parameters(), lr=0.001, warmup=.01, t_total=1000)
self.assertTrue(isinstance(optim.param_groups[0]["schedule"], WarmupLinearSchedule))
# shouldn't fail
m = torch.nn.Linear(50, 50)
optimizer = AdamW(m.parameters(), lr=10.)
num_steps = 10
def assertListAlmostEqual(self, list1, list2, tol):
self.assertEqual(len(list1), len(list2))
for a, b in zip(list1, list2):
self.assertAlmostEqual(a, b, delta=tol)
class WarmupCosineWithRestartsTest(unittest.TestCase):
def test_it(self):
m = WarmupCosineWithWarmupRestartsSchedule(warmup=0.05, t_total=1000., cycles=5)
x = np.arange(0, 1000)
y = [m.get_lr(xe) for xe in x]
y = np.asarray(y)
expected_zeros = y[[0, 200, 400, 600, 800]]
print(expected_zeros)
expected_ones = y[[50, 250, 450, 650, 850]]
print(expected_ones)
self.assertTrue(np.allclose(expected_ones, 1))
self.assertTrue(np.allclose(expected_zeros, 0))
def test_constant_scheduler(self):
scheduler = ConstantLRSchedule(self.optimizer)
lrs = unwrap_schedule(scheduler, self.num_steps)
expected_learning_rates = [10.] * self.num_steps
self.assertEqual(len(lrs[0]), 1)
self.assertListEqual([l[0] for l in lrs], expected_learning_rates)
def test_warmup_constant_scheduler(self):
scheduler = WarmupConstantSchedule(self.optimizer, warmup_steps=4)
lrs = unwrap_schedule(scheduler, self.num_steps)
expected_learning_rates = [2.5, 5.0, 7.5, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0]
self.assertEqual(len(lrs[0]), 1)
self.assertListEqual([l[0] for l in lrs], expected_learning_rates)
def test_warmup_linear_scheduler(self):
scheduler = WarmupLinearSchedule(self.optimizer, warmup_steps=2, t_total=10)
lrs = unwrap_schedule(scheduler, self.num_steps)
expected_learning_rates = [5.0, 10.0, 8.75, 7.5, 6.25, 5.0, 3.75, 2.5, 1.25, 0.0]
self.assertEqual(len(lrs[0]), 1)
self.assertListEqual([l[0] for l in lrs], expected_learning_rates)
def test_warmup_cosine_scheduler(self):
scheduler = WarmupCosineSchedule(self.optimizer, warmup_steps=2, t_total=10)
lrs = unwrap_schedule(scheduler, self.num_steps)
expected_learning_rates = [5.0, 10.0, 9.61, 8.53, 6.91, 5.0, 3.08, 1.46, 0.38, 0.0]
self.assertEqual(len(lrs[0]), 1)
self.assertListAlmostEqual([l[0] for l in lrs], expected_learning_rates, tol=1e-2)
def test_warmup_cosine_hard_restart_scheduler(self):
scheduler = WarmupCosineWithHardRestartsSchedule(self.optimizer, warmup_steps=2, cycles=2, t_total=10)
lrs = unwrap_schedule(scheduler, self.num_steps)
expected_learning_rates = [5.0, 10.0, 8.53, 5.0, 1.46, 10.0, 8.53, 5.0, 1.46, 0.0]
self.assertEqual(len(lrs[0]), 1)
self.assertListAlmostEqual([l[0] for l in lrs], expected_learning_rates, tol=1e-2)
if __name__ == "__main__":