refactor almost identical tests (#6339)
* refactor almost identical tests * important to add a clear assert error message * make the assert error even more descriptive than the original bt
This commit is contained in:
@@ -86,66 +86,51 @@ class ScheduleInitTest(unittest.TestCase):
|
|||||||
optimizer = AdamW(m.parameters(), lr=10.0) if is_torch_available() else None
|
optimizer = AdamW(m.parameters(), lr=10.0) if is_torch_available() else None
|
||||||
num_steps = 10
|
num_steps = 10
|
||||||
|
|
||||||
def assertListAlmostEqual(self, list1, list2, tol):
|
def assertListAlmostEqual(self, list1, list2, tol, msg=None):
|
||||||
self.assertEqual(len(list1), len(list2))
|
self.assertEqual(len(list1), len(list2))
|
||||||
for a, b in zip(list1, list2):
|
for a, b in zip(list1, list2):
|
||||||
self.assertAlmostEqual(a, b, delta=tol)
|
self.assertAlmostEqual(a, b, delta=tol, msg=msg)
|
||||||
|
|
||||||
def test_constant_scheduler(self):
|
def test_schedulers(self):
|
||||||
scheduler = get_constant_schedule(self.optimizer)
|
|
||||||
lrs = unwrap_schedule(scheduler, self.num_steps)
|
|
||||||
expected_learning_rates = [10.0] * self.num_steps
|
|
||||||
self.assertEqual(len(lrs[0]), 1)
|
|
||||||
self.assertListEqual([l[0] for l in lrs], expected_learning_rates)
|
|
||||||
|
|
||||||
scheduler = get_constant_schedule(self.optimizer)
|
common_kwargs = {"num_warmup_steps": 2, "num_training_steps": 10}
|
||||||
lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps)
|
# schedulers doct format
|
||||||
self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2])
|
# function: (sched_args_dict, expected_learning_rates)
|
||||||
|
scheds = {
|
||||||
|
get_constant_schedule: ({}, [10.0] * self.num_steps),
|
||||||
|
get_constant_schedule_with_warmup: (
|
||||||
|
{"num_warmup_steps": 4},
|
||||||
|
[2.5, 5.0, 7.5, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0],
|
||||||
|
),
|
||||||
|
get_linear_schedule_with_warmup: (
|
||||||
|
{**common_kwargs},
|
||||||
|
[5.0, 10.0, 8.75, 7.5, 6.25, 5.0, 3.75, 2.5, 1.25, 0.0],
|
||||||
|
),
|
||||||
|
get_cosine_schedule_with_warmup: (
|
||||||
|
{**common_kwargs},
|
||||||
|
[5.0, 10.0, 9.61, 8.53, 6.91, 5.0, 3.08, 1.46, 0.38, 0.0],
|
||||||
|
),
|
||||||
|
get_cosine_with_hard_restarts_schedule_with_warmup: (
|
||||||
|
{**common_kwargs, "num_cycles": 2},
|
||||||
|
[5.0, 10.0, 8.53, 5.0, 1.46, 10.0, 8.53, 5.0, 1.46, 0.0],
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
def test_warmup_constant_scheduler(self):
|
for scheduler_func, data in scheds.items():
|
||||||
scheduler = get_constant_schedule_with_warmup(self.optimizer, num_warmup_steps=4)
|
kwargs, expected_learning_rates = data
|
||||||
lrs = unwrap_schedule(scheduler, self.num_steps)
|
|
||||||
expected_learning_rates = [2.5, 5.0, 7.5, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0]
|
|
||||||
self.assertEqual(len(lrs[0]), 1)
|
|
||||||
self.assertListEqual([l[0] for l in lrs], expected_learning_rates)
|
|
||||||
|
|
||||||
scheduler = get_constant_schedule_with_warmup(self.optimizer, num_warmup_steps=4)
|
scheduler = scheduler_func(self.optimizer, **kwargs)
|
||||||
lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps)
|
lrs_1 = unwrap_schedule(scheduler, self.num_steps)
|
||||||
self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2])
|
self.assertEqual(len(lrs_1[0]), 1)
|
||||||
|
self.assertListAlmostEqual(
|
||||||
def test_warmup_linear_scheduler(self):
|
[l[0] for l in lrs_1],
|
||||||
scheduler = get_linear_schedule_with_warmup(self.optimizer, num_warmup_steps=2, num_training_steps=10)
|
expected_learning_rates,
|
||||||
lrs = unwrap_schedule(scheduler, self.num_steps)
|
tol=1e-2,
|
||||||
expected_learning_rates = [5.0, 10.0, 8.75, 7.5, 6.25, 5.0, 3.75, 2.5, 1.25, 0.0]
|
msg=f"failed for {scheduler_func} in normal scheduler",
|
||||||
self.assertEqual(len(lrs[0]), 1)
|
|
||||||
self.assertListEqual([l[0] for l in lrs], expected_learning_rates)
|
|
||||||
|
|
||||||
scheduler = get_linear_schedule_with_warmup(self.optimizer, num_warmup_steps=2, num_training_steps=10)
|
|
||||||
lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps)
|
|
||||||
self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2])
|
|
||||||
|
|
||||||
def test_warmup_cosine_scheduler(self):
|
|
||||||
scheduler = get_cosine_schedule_with_warmup(self.optimizer, num_warmup_steps=2, num_training_steps=10)
|
|
||||||
lrs = unwrap_schedule(scheduler, self.num_steps)
|
|
||||||
expected_learning_rates = [5.0, 10.0, 9.61, 8.53, 6.91, 5.0, 3.08, 1.46, 0.38, 0.0]
|
|
||||||
self.assertEqual(len(lrs[0]), 1)
|
|
||||||
self.assertListAlmostEqual([l[0] for l in lrs], expected_learning_rates, tol=1e-2)
|
|
||||||
|
|
||||||
scheduler = get_cosine_schedule_with_warmup(self.optimizer, num_warmup_steps=2, num_training_steps=10)
|
|
||||||
lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps)
|
|
||||||
self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2])
|
|
||||||
|
|
||||||
def test_warmup_cosine_hard_restart_scheduler(self):
|
|
||||||
scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
|
|
||||||
self.optimizer, num_warmup_steps=2, num_cycles=2, num_training_steps=10
|
|
||||||
)
|
)
|
||||||
lrs = unwrap_schedule(scheduler, self.num_steps)
|
|
||||||
expected_learning_rates = [5.0, 10.0, 8.53, 5.0, 1.46, 10.0, 8.53, 5.0, 1.46, 0.0]
|
|
||||||
self.assertEqual(len(lrs[0]), 1)
|
|
||||||
self.assertListAlmostEqual([l[0] for l in lrs], expected_learning_rates, tol=1e-2)
|
|
||||||
|
|
||||||
scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
|
scheduler = scheduler_func(self.optimizer, **kwargs)
|
||||||
self.optimizer, num_warmup_steps=2, num_cycles=2, num_training_steps=10
|
|
||||||
)
|
|
||||||
lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps)
|
lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps)
|
||||||
self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2])
|
self.assertListEqual(
|
||||||
|
[l[0] for l in lrs_1], [l[0] for l in lrs_2], msg=f"failed for {scheduler_func} in save and reload"
|
||||||
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user