Add WSD scheduler (#30231)

* Added WSD scheduler.

* Added tests.

* Fixed errors.

* Fix formatting.

* CI fixes.
This commit is contained in:
Alexander Visheratin
2024-04-25 07:07:21 -04:00
committed by GitHub
parent 90cb55bf77
commit 7b1170b0fa
6 changed files with 82 additions and 0 deletions

View File

@@ -36,6 +36,7 @@ if is_torch_available():
get_inverse_sqrt_schedule,
get_linear_schedule_with_warmup,
get_polynomial_decay_schedule_with_warmup,
get_wsd_schedule,
)
@@ -150,6 +151,10 @@ class ScheduleInitTest(unittest.TestCase):
{"num_warmup_steps": 2},
[0.0, 5.0, 10.0, 8.165, 7.071, 6.325, 5.774, 5.345, 5.0, 4.714],
),
get_wsd_schedule: (
{"num_warmup_steps": 2, "num_stable_steps": 2, "num_decay_steps": 3, "min_lr_ratio": 0.1},
[0.0, 5.0, 10.0, 10.0, 10.0, 7.75, 3.25, 1.0, 1.0, 1.0],
),
}
for scheduler_func, data in scheds.items():