Add WSD scheduler (#30231)
* Added WSD scheduler. * Added tests. * Fixed errors. * Fix formatting. * CI fixes.
This commit is contained in:
committed by
GitHub
parent
90cb55bf77
commit
7b1170b0fa
@@ -66,6 +66,8 @@ The `.optimization` module provides:
|
|||||||
|
|
||||||
[[autodoc]] get_inverse_sqrt_schedule
|
[[autodoc]] get_inverse_sqrt_schedule
|
||||||
|
|
||||||
|
[[autodoc]] get_wsd_schedule
|
||||||
|
|
||||||
### Warmup (TensorFlow)
|
### Warmup (TensorFlow)
|
||||||
|
|
||||||
[[autodoc]] WarmUp
|
[[autodoc]] WarmUp
|
||||||
|
|||||||
@@ -3911,6 +3911,7 @@ else:
|
|||||||
"get_linear_schedule_with_warmup",
|
"get_linear_schedule_with_warmup",
|
||||||
"get_polynomial_decay_schedule_with_warmup",
|
"get_polynomial_decay_schedule_with_warmup",
|
||||||
"get_scheduler",
|
"get_scheduler",
|
||||||
|
"get_wsd_schedule",
|
||||||
]
|
]
|
||||||
_import_structure["pytorch_utils"] = [
|
_import_structure["pytorch_utils"] = [
|
||||||
"Conv1D",
|
"Conv1D",
|
||||||
@@ -8414,6 +8415,7 @@ if TYPE_CHECKING:
|
|||||||
get_linear_schedule_with_warmup,
|
get_linear_schedule_with_warmup,
|
||||||
get_polynomial_decay_schedule_with_warmup,
|
get_polynomial_decay_schedule_with_warmup,
|
||||||
get_scheduler,
|
get_scheduler,
|
||||||
|
get_wsd_schedule,
|
||||||
)
|
)
|
||||||
from .pytorch_utils import Conv1D, apply_chunking_to_forward, prune_layer
|
from .pytorch_utils import Conv1D, apply_chunking_to_forward, prune_layer
|
||||||
|
|
||||||
|
|||||||
@@ -387,6 +387,73 @@ def get_cosine_with_min_lr_schedule_with_warmup(
|
|||||||
return LambdaLR(optimizer, lr_lambda, last_epoch)
|
return LambdaLR(optimizer, lr_lambda, last_epoch)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_wsd_scheduler_lambda(
    current_step: int,
    *,
    num_warmup_steps: int,
    num_stable_steps: int,
    num_decay_steps: int,
    num_cycles: float,
    min_lr_ratio: float,
):
    """
    Compute the learning-rate multiplier of the warmup-stable-decay (WSD) schedule at `current_step`.

    The multiplier goes through the following phases:
        1. warmup: grows linearly from 0 towards 1 over `num_warmup_steps` steps.
        2. stable: stays at 1 for `num_stable_steps` steps.
        3. decay: follows a cosine curve with `num_cycles` waves over `num_decay_steps` steps, rescaled so that
           it never goes below `min_lr_ratio`.
        4. afterwards: stays at `min_lr_ratio`.

    Args:
        current_step (`int`):
            The step for which the multiplier is computed.
        num_warmup_steps (`int`):
            The length of the warmup phase.
        num_stable_steps (`int`):
            The length of the stable phase.
        num_decay_steps (`int`):
            The length of the cosine decay phase.
        num_cycles (`float`):
            The number of cosine waves covered during the decay phase.
        min_lr_ratio (`float`):
            The floor of the multiplier during and after the decay phase.

    Return:
        `float`: The factor by which the base learning rate is multiplied.
    """
    # Phase boundaries, expressed as the first step that no longer belongs to each phase.
    warmup_end = num_warmup_steps
    stable_end = warmup_end + num_stable_steps
    decay_end = stable_end + num_decay_steps

    if current_step < warmup_end:
        # max(1, ...) guards the division when no warmup is requested.
        return float(current_step) / float(max(1, num_warmup_steps))

    if current_step < stable_end:
        return 1.0

    if current_step >= decay_end:
        return min_lr_ratio

    # Fraction of the decay phase that has already elapsed, in [0, 1).
    steps_into_decay = current_step - num_warmup_steps - num_stable_steps
    progress = float(steps_into_decay) / float(max(1, num_decay_steps))

    # Cosine value mapped to [0, 1]; the clamp protects against tiny negative rounding errors.
    angle = math.pi * float(num_cycles) * 2.0 * progress
    value = max(0.0, 0.5 * (1.0 + math.cos(angle)))

    # Rescale [0, 1] to [min_lr_ratio, 1].
    return (1.0 - min_lr_ratio) * value + min_lr_ratio
|
||||||
|
|
||||||
|
|
||||||
|
def get_wsd_schedule(
    optimizer: Optimizer,
    num_warmup_steps: int,
    num_stable_steps: int,
    num_decay_steps: int,
    min_lr_ratio: float = 0,
    num_cycles: float = 0.5,
    last_epoch: int = -1,
):
    """
    Create a warmup-stable-decay (WSD) schedule, i.e. a learning rate that goes through three stages:
    1. a linear increase from 0 to the initial lr set in the optimizer.
    2. a constant lr (equal to the initial lr).
    3. a decrease following the values of the cosine function, from the initial lr down to a fraction of it.

    Once the decay stage is over, the learning rate stays at `min_lr_ratio` times the initial lr.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        num_stable_steps (`int`):
            The number of steps for the stable phase.
        num_decay_steps (`int`):
            The number of steps for the cosine annealing phase.
        min_lr_ratio (`float`, *optional*, defaults to 0):
            The minimum learning rate as a ratio of the initial learning rate.
        num_cycles (`float`, *optional*, defaults to 0.5):
            The number of waves in the cosine schedule (the default is to just decrease from the max value to the
            minimum value following a half-cosine).
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """
    # Everything except the current step is fixed up front; `LambdaLR` supplies the step on each call.
    schedule_settings = {
        "num_warmup_steps": num_warmup_steps,
        "num_stable_steps": num_stable_steps,
        "num_decay_steps": num_decay_steps,
        "min_lr_ratio": min_lr_ratio,
        "num_cycles": num_cycles,
    }
    lr_lambda = partial(_get_wsd_scheduler_lambda, **schedule_settings)
    return LambdaLR(optimizer, lr_lambda, last_epoch)
|
||||||
|
|
||||||
|
|
||||||
TYPE_TO_SCHEDULER_FUNCTION = {
|
TYPE_TO_SCHEDULER_FUNCTION = {
|
||||||
SchedulerType.LINEAR: get_linear_schedule_with_warmup,
|
SchedulerType.LINEAR: get_linear_schedule_with_warmup,
|
||||||
SchedulerType.COSINE: get_cosine_schedule_with_warmup,
|
SchedulerType.COSINE: get_cosine_schedule_with_warmup,
|
||||||
@@ -397,6 +464,7 @@ TYPE_TO_SCHEDULER_FUNCTION = {
|
|||||||
SchedulerType.INVERSE_SQRT: get_inverse_sqrt_schedule,
|
SchedulerType.INVERSE_SQRT: get_inverse_sqrt_schedule,
|
||||||
SchedulerType.REDUCE_ON_PLATEAU: get_reduce_on_plateau_schedule,
|
SchedulerType.REDUCE_ON_PLATEAU: get_reduce_on_plateau_schedule,
|
||||||
SchedulerType.COSINE_WITH_MIN_LR: get_cosine_with_min_lr_schedule_with_warmup,
|
SchedulerType.COSINE_WITH_MIN_LR: get_cosine_with_min_lr_schedule_with_warmup,
|
||||||
|
SchedulerType.WARMUP_STABLE_DECAY: get_wsd_schedule,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -412,6 +412,7 @@ class SchedulerType(ExplicitEnum):
|
|||||||
INVERSE_SQRT = "inverse_sqrt"
|
INVERSE_SQRT = "inverse_sqrt"
|
||||||
REDUCE_ON_PLATEAU = "reduce_lr_on_plateau"
|
REDUCE_ON_PLATEAU = "reduce_lr_on_plateau"
|
||||||
COSINE_WITH_MIN_LR = "cosine_with_min_lr"
|
COSINE_WITH_MIN_LR = "cosine_with_min_lr"
|
||||||
|
WARMUP_STABLE_DECAY = "warmup_stable_decay"
|
||||||
|
|
||||||
|
|
||||||
class TrainerMemoryTracker:
|
class TrainerMemoryTracker:
|
||||||
|
|||||||
@@ -10023,6 +10023,10 @@ def get_scheduler(*args, **kwargs):
|
|||||||
requires_backends(get_scheduler, ["torch"])
|
requires_backends(get_scheduler, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
def get_wsd_schedule(*args, **kwargs):
    """
    Placeholder for `get_wsd_schedule` that accepts any arguments and ignores them.

    It only calls `requires_backends` for the "torch" backend; presumably that helper raises an informative
    error when torch is not installed (confirm against `requires_backends`).
    """
    required_backends = ["torch"]
    requires_backends(get_wsd_schedule, required_backends)
|
||||||
|
|
||||||
|
|
||||||
class Conv1D(metaclass=DummyObject):
|
class Conv1D(metaclass=DummyObject):
|
||||||
_backends = ["torch"]
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
|||||||
@@ -36,6 +36,7 @@ if is_torch_available():
|
|||||||
get_inverse_sqrt_schedule,
|
get_inverse_sqrt_schedule,
|
||||||
get_linear_schedule_with_warmup,
|
get_linear_schedule_with_warmup,
|
||||||
get_polynomial_decay_schedule_with_warmup,
|
get_polynomial_decay_schedule_with_warmup,
|
||||||
|
get_wsd_schedule,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -150,6 +151,10 @@ class ScheduleInitTest(unittest.TestCase):
|
|||||||
{"num_warmup_steps": 2},
|
{"num_warmup_steps": 2},
|
||||||
[0.0, 5.0, 10.0, 8.165, 7.071, 6.325, 5.774, 5.345, 5.0, 4.714],
|
[0.0, 5.0, 10.0, 8.165, 7.071, 6.325, 5.774, 5.345, 5.0, 4.714],
|
||||||
),
|
),
|
||||||
|
get_wsd_schedule: (
|
||||||
|
{"num_warmup_steps": 2, "num_stable_steps": 2, "num_decay_steps": 3, "min_lr_ratio": 0.1},
|
||||||
|
[0.0, 5.0, 10.0, 10.0, 10.0, 7.75, 3.25, 1.0, 1.0, 1.0],
|
||||||
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
for scheduler_func, data in scheds.items():
|
for scheduler_func, data in scheds.items():
|
||||||
|
|||||||
Reference in New Issue
Block a user