[test schedulers] adjust to test the first step's reading (#6429)
* [test schedulers] small improvement * cleanup
This commit is contained in:
@@ -40,16 +40,16 @@ if is_torch_available():
|
|||||||
def unwrap_schedule(scheduler, num_steps=10):
|
def unwrap_schedule(scheduler, num_steps=10):
|
||||||
lrs = []
|
lrs = []
|
||||||
for _ in range(num_steps):
|
for _ in range(num_steps):
|
||||||
|
lrs.append(scheduler.get_lr()[0])
|
||||||
scheduler.step()
|
scheduler.step()
|
||||||
lrs.append(scheduler.get_lr())
|
|
||||||
return lrs
|
return lrs
|
||||||
|
|
||||||
|
|
||||||
def unwrap_and_save_reload_schedule(scheduler, num_steps=10):
|
def unwrap_and_save_reload_schedule(scheduler, num_steps=10):
|
||||||
lrs = []
|
lrs = []
|
||||||
for step in range(num_steps):
|
for step in range(num_steps):
|
||||||
|
lrs.append(scheduler.get_lr()[0])
|
||||||
scheduler.step()
|
scheduler.step()
|
||||||
lrs.append(scheduler.get_lr())
|
|
||||||
if step == num_steps // 2:
|
if step == num_steps // 2:
|
||||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
file_name = os.path.join(tmpdirname, "schedule.bin")
|
file_name = os.path.join(tmpdirname, "schedule.bin")
|
||||||
@@ -127,23 +127,23 @@ class ScheduleInitTest(unittest.TestCase):
|
|||||||
get_constant_schedule: ({}, [10.0] * self.num_steps),
|
get_constant_schedule: ({}, [10.0] * self.num_steps),
|
||||||
get_constant_schedule_with_warmup: (
|
get_constant_schedule_with_warmup: (
|
||||||
{"num_warmup_steps": 4},
|
{"num_warmup_steps": 4},
|
||||||
[2.5, 5.0, 7.5, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0],
|
[0.0, 2.5, 5.0, 7.5, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0],
|
||||||
),
|
),
|
||||||
get_linear_schedule_with_warmup: (
|
get_linear_schedule_with_warmup: (
|
||||||
{**common_kwargs},
|
{**common_kwargs},
|
||||||
[5.0, 10.0, 8.75, 7.5, 6.25, 5.0, 3.75, 2.5, 1.25, 0.0],
|
[0.0, 5.0, 10.0, 8.75, 7.5, 6.25, 5.0, 3.75, 2.5, 1.25],
|
||||||
),
|
),
|
||||||
get_cosine_schedule_with_warmup: (
|
get_cosine_schedule_with_warmup: (
|
||||||
{**common_kwargs},
|
{**common_kwargs},
|
||||||
[5.0, 10.0, 9.61, 8.53, 6.91, 5.0, 3.08, 1.46, 0.38, 0.0],
|
[0.0, 5.0, 10.0, 9.61, 8.53, 6.91, 5.0, 3.08, 1.46, 0.38],
|
||||||
),
|
),
|
||||||
get_cosine_with_hard_restarts_schedule_with_warmup: (
|
get_cosine_with_hard_restarts_schedule_with_warmup: (
|
||||||
{**common_kwargs, "num_cycles": 2},
|
{**common_kwargs, "num_cycles": 2},
|
||||||
[5.0, 10.0, 8.53, 5.0, 1.46, 10.0, 8.53, 5.0, 1.46, 0.0],
|
[0.0, 5.0, 10.0, 8.53, 5.0, 1.46, 10.0, 8.53, 5.0, 1.46],
|
||||||
),
|
),
|
||||||
get_polynomial_decay_schedule_with_warmup: (
|
get_polynomial_decay_schedule_with_warmup: (
|
||||||
{**common_kwargs, "power": 2.0, "lr_end": 1e-7},
|
{**common_kwargs, "power": 2.0, "lr_end": 1e-7},
|
||||||
[5.0, 10.0, 7.656, 5.625, 3.906, 2.5, 1.406, 0.625, 0.156, 1e-07],
|
[0.0, 5.0, 10.0, 7.656, 5.625, 3.906, 2.5, 1.406, 0.625, 0.156],
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -151,17 +151,12 @@ class ScheduleInitTest(unittest.TestCase):
|
|||||||
kwargs, expected_learning_rates = data
|
kwargs, expected_learning_rates = data
|
||||||
|
|
||||||
scheduler = scheduler_func(self.optimizer, **kwargs)
|
scheduler = scheduler_func(self.optimizer, **kwargs)
|
||||||
|
self.assertEqual(len([scheduler.get_lr()[0]]), 1)
|
||||||
lrs_1 = unwrap_schedule(scheduler, self.num_steps)
|
lrs_1 = unwrap_schedule(scheduler, self.num_steps)
|
||||||
self.assertEqual(len(lrs_1[0]), 1)
|
|
||||||
self.assertListAlmostEqual(
|
self.assertListAlmostEqual(
|
||||||
[l[0] for l in lrs_1],
|
lrs_1, expected_learning_rates, tol=1e-2, msg=f"failed for {scheduler_func} in normal scheduler",
|
||||||
expected_learning_rates,
|
|
||||||
tol=1e-2,
|
|
||||||
msg=f"failed for {scheduler_func} in normal scheduler",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
scheduler = scheduler_func(self.optimizer, **kwargs)
|
scheduler = scheduler_func(self.optimizer, **kwargs)
|
||||||
lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps)
|
lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps)
|
||||||
self.assertListEqual(
|
self.assertListEqual(lrs_1, lrs_2, msg=f"failed for {scheduler_func} in save and reload")
|
||||||
[l[0] for l in lrs_1], [l[0] for l in lrs_2], msg=f"failed for {scheduler_func} in save and reload"
|
|
||||||
)
|
|
||||||
|
|||||||
Reference in New Issue
Block a user