Add WSD scheduler (#30231)
* Added WSD scheduler. * Added tests. * Fixed errors. * Fix formatting. * CI fixes.
This commit is contained in:
committed by
GitHub
parent
90cb55bf77
commit
7b1170b0fa
@@ -36,6 +36,7 @@ if is_torch_available():
|
||||
get_inverse_sqrt_schedule,
|
||||
get_linear_schedule_with_warmup,
|
||||
get_polynomial_decay_schedule_with_warmup,
|
||||
get_wsd_schedule,
|
||||
)
|
||||
|
||||
|
||||
@@ -150,6 +151,10 @@ class ScheduleInitTest(unittest.TestCase):
|
||||
{"num_warmup_steps": 2},
|
||||
[0.0, 5.0, 10.0, 8.165, 7.071, 6.325, 5.774, 5.345, 5.0, 4.714],
|
||||
),
|
||||
get_wsd_schedule: (
|
||||
{"num_warmup_steps": 2, "num_stable_steps": 2, "num_decay_steps": 3, "min_lr_ratio": 0.1},
|
||||
[0.0, 5.0, 10.0, 10.0, 10.0, 7.75, 3.25, 1.0, 1.0, 1.0],
|
||||
),
|
||||
}
|
||||
|
||||
for scheduler_func, data in scheds.items():
|
||||
|
||||
Reference in New Issue
Block a user