optimization tests

This commit is contained in:
thomwolf
2019-07-11 17:39:47 +02:00
parent e4f9dca018
commit ccb6947dc1
4 changed files with 91 additions and 49 deletions

View File

@@ -26,6 +26,13 @@ from pytorch_transformers import (AdamW, ConstantLRSchedule, WarmupConstantSched
import numpy as np
def unwrap_schedule(scheduler, num_steps=10):
lrs = []
for _ in range(num_steps):
scheduler.step()
lrs.append(scheduler.get_lr())
return lrs
class OptimizationTest(unittest.TestCase):
def assertListAlmostEqual(self, list1, list2, tol):
@@ -38,9 +45,7 @@ class OptimizationTest(unittest.TestCase):
target = torch.tensor([0.4, 0.2, -0.5])
criterion = torch.nn.MSELoss()
# No warmup, constant schedule, no gradient clipping
optimizer = AdamW(params=[w], lr=2e-1,
weight_decay=0.0,
max_grad_norm=-1)
optimizer = AdamW(params=[w], lr=2e-1, weight_decay=0.0)
for _ in range(100):
loss = criterion(w, target)
loss.backward()
@@ -51,29 +56,49 @@ class OptimizationTest(unittest.TestCase):
class ScheduleInitTest(unittest.TestCase):
def test_sched_init(self):
m = torch.nn.Linear(50, 50)
optim = AdamW(m.parameters(), lr=0.001, warmup=.1, t_total=1000, schedule=None)
self.assertTrue(isinstance(optim.param_groups[0]["schedule"], ConstantLR))
optim = AdamW(m.parameters(), lr=0.001, warmup=.1, t_total=1000, schedule="none")
self.assertTrue(isinstance(optim.param_groups[0]["schedule"], ConstantLR))
optim = AdamW(m.parameters(), lr=0.001, warmup=.01, t_total=1000)
self.assertTrue(isinstance(optim.param_groups[0]["schedule"], WarmupLinearSchedule))
# shouldn't fail
m = torch.nn.Linear(50, 50)
optimizer = AdamW(m.parameters(), lr=10.)
num_steps = 10
def assertListAlmostEqual(self, list1, list2, tol):
self.assertEqual(len(list1), len(list2))
for a, b in zip(list1, list2):
self.assertAlmostEqual(a, b, delta=tol)
class WarmupCosineWithRestartsTest(unittest.TestCase):
def test_it(self):
m = WarmupCosineWithWarmupRestartsSchedule(warmup=0.05, t_total=1000., cycles=5)
x = np.arange(0, 1000)
y = [m.get_lr(xe) for xe in x]
y = np.asarray(y)
expected_zeros = y[[0, 200, 400, 600, 800]]
print(expected_zeros)
expected_ones = y[[50, 250, 450, 650, 850]]
print(expected_ones)
self.assertTrue(np.allclose(expected_ones, 1))
self.assertTrue(np.allclose(expected_zeros, 0))
def test_constant_scheduler(self):
scheduler = ConstantLRSchedule(self.optimizer)
lrs = unwrap_schedule(scheduler, self.num_steps)
expected_learning_rates = [10.] * self.num_steps
self.assertEqual(len(lrs[0]), 1)
self.assertListEqual([l[0] for l in lrs], expected_learning_rates)
def test_warmup_constant_scheduler(self):
scheduler = WarmupConstantSchedule(self.optimizer, warmup_steps=4)
lrs = unwrap_schedule(scheduler, self.num_steps)
expected_learning_rates = [2.5, 5.0, 7.5, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0]
self.assertEqual(len(lrs[0]), 1)
self.assertListEqual([l[0] for l in lrs], expected_learning_rates)
def test_warmup_linear_scheduler(self):
scheduler = WarmupLinearSchedule(self.optimizer, warmup_steps=2, t_total=10)
lrs = unwrap_schedule(scheduler, self.num_steps)
expected_learning_rates = [5.0, 10.0, 8.75, 7.5, 6.25, 5.0, 3.75, 2.5, 1.25, 0.0]
self.assertEqual(len(lrs[0]), 1)
self.assertListEqual([l[0] for l in lrs], expected_learning_rates)
def test_warmup_cosine_scheduler(self):
scheduler = WarmupCosineSchedule(self.optimizer, warmup_steps=2, t_total=10)
lrs = unwrap_schedule(scheduler, self.num_steps)
expected_learning_rates = [5.0, 10.0, 9.61, 8.53, 6.91, 5.0, 3.08, 1.46, 0.38, 0.0]
self.assertEqual(len(lrs[0]), 1)
self.assertListAlmostEqual([l[0] for l in lrs], expected_learning_rates, tol=1e-2)
def test_warmup_cosine_hard_restart_scheduler(self):
scheduler = WarmupCosineWithHardRestartsSchedule(self.optimizer, warmup_steps=2, cycles=2, t_total=10)
lrs = unwrap_schedule(scheduler, self.num_steps)
expected_learning_rates = [5.0, 10.0, 8.53, 5.0, 1.46, 10.0, 8.53, 5.0, 1.46, 0.0]
self.assertEqual(len(lrs[0]), 1)
self.assertListAlmostEqual([l[0] for l in lrs], expected_learning_rates, tol=1e-2)
if __name__ == "__main__":