optimization tests
This commit is contained in:
@@ -26,6 +26,13 @@ from pytorch_transformers import (AdamW, ConstantLRSchedule, WarmupConstantSched
|
||||
import numpy as np
|
||||
|
||||
|
||||
def unwrap_schedule(scheduler, num_steps=10):
|
||||
lrs = []
|
||||
for _ in range(num_steps):
|
||||
scheduler.step()
|
||||
lrs.append(scheduler.get_lr())
|
||||
return lrs
|
||||
|
||||
class OptimizationTest(unittest.TestCase):
|
||||
|
||||
def assertListAlmostEqual(self, list1, list2, tol):
|
||||
@@ -38,9 +45,7 @@ class OptimizationTest(unittest.TestCase):
|
||||
target = torch.tensor([0.4, 0.2, -0.5])
|
||||
criterion = torch.nn.MSELoss()
|
||||
# No warmup, constant schedule, no gradient clipping
|
||||
optimizer = AdamW(params=[w], lr=2e-1,
|
||||
weight_decay=0.0,
|
||||
max_grad_norm=-1)
|
||||
optimizer = AdamW(params=[w], lr=2e-1, weight_decay=0.0)
|
||||
for _ in range(100):
|
||||
loss = criterion(w, target)
|
||||
loss.backward()
|
||||
@@ -51,29 +56,49 @@ class OptimizationTest(unittest.TestCase):
|
||||
|
||||
|
||||
class ScheduleInitTest(unittest.TestCase):
|
||||
def test_sched_init(self):
|
||||
m = torch.nn.Linear(50, 50)
|
||||
optim = AdamW(m.parameters(), lr=0.001, warmup=.1, t_total=1000, schedule=None)
|
||||
self.assertTrue(isinstance(optim.param_groups[0]["schedule"], ConstantLR))
|
||||
optim = AdamW(m.parameters(), lr=0.001, warmup=.1, t_total=1000, schedule="none")
|
||||
self.assertTrue(isinstance(optim.param_groups[0]["schedule"], ConstantLR))
|
||||
optim = AdamW(m.parameters(), lr=0.001, warmup=.01, t_total=1000)
|
||||
self.assertTrue(isinstance(optim.param_groups[0]["schedule"], WarmupLinearSchedule))
|
||||
# shouldn't fail
|
||||
m = torch.nn.Linear(50, 50)
|
||||
optimizer = AdamW(m.parameters(), lr=10.)
|
||||
num_steps = 10
|
||||
|
||||
def assertListAlmostEqual(self, list1, list2, tol):
|
||||
self.assertEqual(len(list1), len(list2))
|
||||
for a, b in zip(list1, list2):
|
||||
self.assertAlmostEqual(a, b, delta=tol)
|
||||
|
||||
class WarmupCosineWithRestartsTest(unittest.TestCase):
|
||||
def test_it(self):
|
||||
m = WarmupCosineWithWarmupRestartsSchedule(warmup=0.05, t_total=1000., cycles=5)
|
||||
x = np.arange(0, 1000)
|
||||
y = [m.get_lr(xe) for xe in x]
|
||||
y = np.asarray(y)
|
||||
expected_zeros = y[[0, 200, 400, 600, 800]]
|
||||
print(expected_zeros)
|
||||
expected_ones = y[[50, 250, 450, 650, 850]]
|
||||
print(expected_ones)
|
||||
self.assertTrue(np.allclose(expected_ones, 1))
|
||||
self.assertTrue(np.allclose(expected_zeros, 0))
|
||||
def test_constant_scheduler(self):
|
||||
scheduler = ConstantLRSchedule(self.optimizer)
|
||||
lrs = unwrap_schedule(scheduler, self.num_steps)
|
||||
expected_learning_rates = [10.] * self.num_steps
|
||||
self.assertEqual(len(lrs[0]), 1)
|
||||
self.assertListEqual([l[0] for l in lrs], expected_learning_rates)
|
||||
|
||||
def test_warmup_constant_scheduler(self):
|
||||
scheduler = WarmupConstantSchedule(self.optimizer, warmup_steps=4)
|
||||
lrs = unwrap_schedule(scheduler, self.num_steps)
|
||||
expected_learning_rates = [2.5, 5.0, 7.5, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0]
|
||||
self.assertEqual(len(lrs[0]), 1)
|
||||
self.assertListEqual([l[0] for l in lrs], expected_learning_rates)
|
||||
|
||||
def test_warmup_linear_scheduler(self):
|
||||
scheduler = WarmupLinearSchedule(self.optimizer, warmup_steps=2, t_total=10)
|
||||
lrs = unwrap_schedule(scheduler, self.num_steps)
|
||||
expected_learning_rates = [5.0, 10.0, 8.75, 7.5, 6.25, 5.0, 3.75, 2.5, 1.25, 0.0]
|
||||
self.assertEqual(len(lrs[0]), 1)
|
||||
self.assertListEqual([l[0] for l in lrs], expected_learning_rates)
|
||||
|
||||
def test_warmup_cosine_scheduler(self):
|
||||
scheduler = WarmupCosineSchedule(self.optimizer, warmup_steps=2, t_total=10)
|
||||
lrs = unwrap_schedule(scheduler, self.num_steps)
|
||||
expected_learning_rates = [5.0, 10.0, 9.61, 8.53, 6.91, 5.0, 3.08, 1.46, 0.38, 0.0]
|
||||
self.assertEqual(len(lrs[0]), 1)
|
||||
self.assertListAlmostEqual([l[0] for l in lrs], expected_learning_rates, tol=1e-2)
|
||||
|
||||
def test_warmup_cosine_hard_restart_scheduler(self):
|
||||
scheduler = WarmupCosineWithHardRestartsSchedule(self.optimizer, warmup_steps=2, cycles=2, t_total=10)
|
||||
lrs = unwrap_schedule(scheduler, self.num_steps)
|
||||
expected_learning_rates = [5.0, 10.0, 8.53, 5.0, 1.46, 10.0, 8.53, 5.0, 1.46, 0.0]
|
||||
self.assertEqual(len(lrs[0]), 1)
|
||||
self.assertListAlmostEqual([l[0] for l in lrs], expected_learning_rates, tol=1e-2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user