From 1429b920d44d610eaa0a6f48de43853da52e9c03 Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Mon, 10 Aug 2020 02:31:20 -0700 Subject: [PATCH] refactor almost identical tests (#6339) * refactor almost identical tests * important to add a clear assert error message * make the assert error even more descriptive than the original bt --- tests/test_optimization.py | 97 ++++++++++++++++---------------------- 1 file changed, 41 insertions(+), 56 deletions(-) diff --git a/tests/test_optimization.py b/tests/test_optimization.py index a38e764319..65687a043e 100644 --- a/tests/test_optimization.py +++ b/tests/test_optimization.py @@ -86,66 +86,51 @@ class ScheduleInitTest(unittest.TestCase): optimizer = AdamW(m.parameters(), lr=10.0) if is_torch_available() else None num_steps = 10 - def assertListAlmostEqual(self, list1, list2, tol): + def assertListAlmostEqual(self, list1, list2, tol, msg=None): self.assertEqual(len(list1), len(list2)) for a, b in zip(list1, list2): - self.assertAlmostEqual(a, b, delta=tol) + self.assertAlmostEqual(a, b, delta=tol, msg=msg) - def test_constant_scheduler(self): - scheduler = get_constant_schedule(self.optimizer) - lrs = unwrap_schedule(scheduler, self.num_steps) - expected_learning_rates = [10.0] * self.num_steps - self.assertEqual(len(lrs[0]), 1) - self.assertListEqual([l[0] for l in lrs], expected_learning_rates) + def test_schedulers(self): - scheduler = get_constant_schedule(self.optimizer) - lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps) - self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2]) + common_kwargs = {"num_warmup_steps": 2, "num_training_steps": 10} + # schedulers doct format + # function: (sched_args_dict, expected_learning_rates) + scheds = { + get_constant_schedule: ({}, [10.0] * self.num_steps), + get_constant_schedule_with_warmup: ( + {"num_warmup_steps": 4}, + [2.5, 5.0, 7.5, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0], + ), + get_linear_schedule_with_warmup: ( + {**common_kwargs}, + [5.0, 10.0, 8.75, 7.5, 6.25, 5.0, 3.75, 2.5, 1.25, 0.0], + ), + get_cosine_schedule_with_warmup: ( + {**common_kwargs}, + [5.0, 10.0, 9.61, 8.53, 6.91, 5.0, 3.08, 1.46, 0.38, 0.0], + ), + get_cosine_with_hard_restarts_schedule_with_warmup: ( + {**common_kwargs, "num_cycles": 2}, + [5.0, 10.0, 8.53, 5.0, 1.46, 10.0, 8.53, 5.0, 1.46, 0.0], + ), + } - def test_warmup_constant_scheduler(self): - scheduler = get_constant_schedule_with_warmup(self.optimizer, num_warmup_steps=4) - lrs = unwrap_schedule(scheduler, self.num_steps) - expected_learning_rates = [2.5, 5.0, 7.5, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0] - self.assertEqual(len(lrs[0]), 1) - self.assertListEqual([l[0] for l in lrs], expected_learning_rates) + for scheduler_func, data in scheds.items(): + kwargs, expected_learning_rates = data - scheduler = get_constant_schedule_with_warmup(self.optimizer, num_warmup_steps=4) - lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps) - self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2]) + scheduler = scheduler_func(self.optimizer, **kwargs) + lrs_1 = unwrap_schedule(scheduler, self.num_steps) + self.assertEqual(len(lrs_1[0]), 1) + self.assertListAlmostEqual( + [l[0] for l in lrs_1], + expected_learning_rates, + tol=1e-2, + msg=f"failed for {scheduler_func} in normal scheduler", + ) - def test_warmup_linear_scheduler(self): - scheduler = get_linear_schedule_with_warmup(self.optimizer, num_warmup_steps=2, num_training_steps=10) - lrs = unwrap_schedule(scheduler, self.num_steps) - expected_learning_rates = [5.0, 10.0, 8.75, 7.5, 6.25, 5.0, 3.75, 2.5, 1.25, 0.0] - self.assertEqual(len(lrs[0]), 1) - self.assertListEqual([l[0] for l in lrs], expected_learning_rates) - - scheduler = get_linear_schedule_with_warmup(self.optimizer, num_warmup_steps=2, num_training_steps=10) - lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps) - self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2]) - - def test_warmup_cosine_scheduler(self): - scheduler = get_cosine_schedule_with_warmup(self.optimizer, num_warmup_steps=2, num_training_steps=10) - lrs = unwrap_schedule(scheduler, self.num_steps) - expected_learning_rates = [5.0, 10.0, 9.61, 8.53, 6.91, 5.0, 3.08, 1.46, 0.38, 0.0] - self.assertEqual(len(lrs[0]), 1) - self.assertListAlmostEqual([l[0] for l in lrs], expected_learning_rates, tol=1e-2) - - scheduler = get_cosine_schedule_with_warmup(self.optimizer, num_warmup_steps=2, num_training_steps=10) - lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps) - self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2]) - - def test_warmup_cosine_hard_restart_scheduler(self): - scheduler = get_cosine_with_hard_restarts_schedule_with_warmup( - self.optimizer, num_warmup_steps=2, num_cycles=2, num_training_steps=10 - ) - lrs = unwrap_schedule(scheduler, self.num_steps) - expected_learning_rates = [5.0, 10.0, 8.53, 5.0, 1.46, 10.0, 8.53, 5.0, 1.46, 0.0] - self.assertEqual(len(lrs[0]), 1) - self.assertListAlmostEqual([l[0] for l in lrs], expected_learning_rates, tol=1e-2) - - scheduler = get_cosine_with_hard_restarts_schedule_with_warmup( - self.optimizer, num_warmup_steps=2, num_cycles=2, num_training_steps=10 - ) - lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps) - self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2]) + scheduler = scheduler_func(self.optimizer, **kwargs) + lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps) + self.assertListEqual( + [l[0] for l in lrs_1], [l[0] for l in lrs_2], msg=f"failed for {scheduler_func} in save and reload" + )