diff --git a/tests/test_optimization.py b/tests/test_optimization.py index 54eff2ecdc..3c79c40208 100644 --- a/tests/test_optimization.py +++ b/tests/test_optimization.py @@ -40,16 +40,16 @@ if is_torch_available(): def unwrap_schedule(scheduler, num_steps=10): lrs = [] for _ in range(num_steps): + lrs.append(scheduler.get_lr()[0]) scheduler.step() - lrs.append(scheduler.get_lr()) return lrs def unwrap_and_save_reload_schedule(scheduler, num_steps=10): lrs = [] for step in range(num_steps): + lrs.append(scheduler.get_lr()[0]) scheduler.step() - lrs.append(scheduler.get_lr()) if step == num_steps // 2: with tempfile.TemporaryDirectory() as tmpdirname: file_name = os.path.join(tmpdirname, "schedule.bin") @@ -127,23 +127,23 @@ class ScheduleInitTest(unittest.TestCase): get_constant_schedule: ({}, [10.0] * self.num_steps), get_constant_schedule_with_warmup: ( {"num_warmup_steps": 4}, - [2.5, 5.0, 7.5, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0], + [0.0, 2.5, 5.0, 7.5, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0], ), get_linear_schedule_with_warmup: ( {**common_kwargs}, - [5.0, 10.0, 8.75, 7.5, 6.25, 5.0, 3.75, 2.5, 1.25, 0.0], + [0.0, 5.0, 10.0, 8.75, 7.5, 6.25, 5.0, 3.75, 2.5, 1.25], ), get_cosine_schedule_with_warmup: ( {**common_kwargs}, - [5.0, 10.0, 9.61, 8.53, 6.91, 5.0, 3.08, 1.46, 0.38, 0.0], + [0.0, 5.0, 10.0, 9.61, 8.53, 6.91, 5.0, 3.08, 1.46, 0.38], ), get_cosine_with_hard_restarts_schedule_with_warmup: ( {**common_kwargs, "num_cycles": 2}, - [5.0, 10.0, 8.53, 5.0, 1.46, 10.0, 8.53, 5.0, 1.46, 0.0], + [0.0, 5.0, 10.0, 8.53, 5.0, 1.46, 10.0, 8.53, 5.0, 1.46], ), get_polynomial_decay_schedule_with_warmup: ( {**common_kwargs, "power": 2.0, "lr_end": 1e-7}, - [5.0, 10.0, 7.656, 5.625, 3.906, 2.5, 1.406, 0.625, 0.156, 1e-07], + [0.0, 5.0, 10.0, 7.656, 5.625, 3.906, 2.5, 1.406, 0.625, 0.156], ), } @@ -151,17 +151,12 @@ class ScheduleInitTest(unittest.TestCase): kwargs, expected_learning_rates = data scheduler = scheduler_func(self.optimizer, **kwargs) + self.assertEqual(len([scheduler.get_lr()[0]]), 1) lrs_1 = unwrap_schedule(scheduler, self.num_steps) - self.assertEqual(len(lrs_1[0]), 1) self.assertListAlmostEqual( - [l[0] for l in lrs_1], - expected_learning_rates, - tol=1e-2, - msg=f"failed for {scheduler_func} in normal scheduler", + lrs_1, expected_learning_rates, tol=1e-2, msg=f"failed for {scheduler_func} in normal scheduler", ) scheduler = scheduler_func(self.optimizer, **kwargs) lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps) - self.assertListEqual( - [l[0] for l in lrs_1], [l[0] for l in lrs_2], msg=f"failed for {scheduler_func} in save and reload" - ) + self.assertListEqual(lrs_1, lrs_2, msg=f"failed for {scheduler_func} in save and reload")