From 48a309d0d21383f72a461d3c5e9b4c639f373bb9 Mon Sep 17 00:00:00 2001 From: Jingze Shi Date: Mon, 10 Feb 2025 20:21:55 +0800 Subject: [PATCH] Support constant lr with cooldown (#35453) * Add support for constant learning rate with cooldown * Add support for constant learning rate with cooldown * Add support for constant learning rate with cooldown * Add support for constant learning rate with cooldown * Add support for constant learning rate with cooldown * Add support for constant learning rate with cooldown * Add support for constant learning rate with cooldown * Add more warmup and cooldown methods to 'get_wsc_schedule' * Add more warmup and cooldown methods to 'get_wsc_schedule' * Add more warmup and cooldown methods to 'get_wsc_schedule' * Add more warmup and cooldown methods to 'get_wsc_schedule' * Add more warmup and decay methods to 'get_wsd_schedule' * support num_training_steps and num_stable_steps for get_wsd_schedule * support num_training_steps and num_stable_steps for get_wsd_schedule * get wsd scheduler before the `num_training_steps` decision * fix code_quality * Update stable branch logic * fix code_quality * Move stable stage decide to `get_wsd_schedule` * Update docstring of `get_wsd_schedule` * Update `num_train_steps` to optional * Update `num_train_steps` to optional * Update docstring of `get_wsd_schedule` * Update src/transformers/optimization.py Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> --------- Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> --- src/transformers/optimization.py | 75 ++++++++++++++++++++----- tests/optimization/test_optimization.py | 28 +++++++-- 2 files changed, 86 insertions(+), 17 deletions(-) diff --git a/src/transformers/optimization.py b/src/transformers/optimization.py index 0ca5d36d0f..d00c65925e 100644 --- a/src/transformers/optimization.py +++ b/src/transformers/optimization.py @@ -393,45 +393,71 @@ def _get_wsd_scheduler_lambda( num_warmup_steps: int, num_stable_steps: int, num_decay_steps: int, - num_cycles: float, + warmup_type: str, + decay_type: str, min_lr_ratio: float, + num_cycles: float, ): if current_step < num_warmup_steps: - return float(current_step) / float(max(1, num_warmup_steps)) + progress = float(current_step) / float(max(1, num_warmup_steps)) + if warmup_type == "linear": + factor = progress + elif warmup_type == "cosine": + factor = 0.5 * (1.0 - math.cos(math.pi * progress)) + elif warmup_type == "1-sqrt": + factor = 1.0 - math.sqrt(1.0 - progress) + factor = factor * (1.0 - min_lr_ratio) + min_lr_ratio + return max(0.0, factor) + if current_step < num_warmup_steps + num_stable_steps: return 1.0 + if current_step < num_warmup_steps + num_stable_steps + num_decay_steps: progress = float(current_step - num_warmup_steps - num_stable_steps) / float(max(1, num_decay_steps)) - value = max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))) - return (1.0 - min_lr_ratio) * value + min_lr_ratio + if decay_type == "linear": + factor = 1.0 - progress + elif decay_type == "cosine": + factor = 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)) + elif decay_type == "1-sqrt": + factor = 1.0 - math.sqrt(progress) + factor = factor * (1.0 - min_lr_ratio) + min_lr_ratio + return max(0.0, factor) return min_lr_ratio def get_wsd_schedule( optimizer: Optimizer, num_warmup_steps: int, - num_stable_steps: int, num_decay_steps: int, + num_training_steps: Optional[int] = None, + num_stable_steps: Optional[int] = None, + warmup_type: str = "linear", + decay_type: str = "cosine", min_lr_ratio: float = 0, num_cycles: float = 0.5, last_epoch: int = -1, ): """ Create a schedule with a learning rate that has three stages: - 1. linear increase from 0 to initial lr. - 2. constant lr (equal to initial lr). - 3. decrease following the values of the cosine function between the initial lr set in the optimizer to - a fraction of initial lr. + 1. warmup: increase from min_lr_ratio times the initial learning rate to the initial learning rate following a warmup_type. + 2. stable: constant learning rate. + 3. decay: decrease from the initial learning rate to min_lr_ratio times the initial learning rate following a decay_type. Args: optimizer ([`~torch.optim.Optimizer`]): The optimizer for which to schedule the learning rate. num_warmup_steps (`int`): The number of steps for the warmup phase. - num_stable_steps (`int`): - The number of steps for the stable phase. num_decay_steps (`int`): - The number of steps for the cosine annealing phase. + The number of steps for the decay phase. + num_training_steps (`int`, *optional*): + The total number of training steps. This is the sum of the warmup, stable and decay steps. If `num_stable_steps` is not provided, the stable phase will be `num_training_steps - num_warmup_steps - num_decay_steps`. + num_stable_steps (`int`, *optional*): + The number of steps for the stable phase. Please ensure that `num_warmup_steps + num_stable_steps + num_decay_steps` equals `num_training_steps`, otherwise the other steps will default to the minimum learning rate. + warmup_type (`str`, *optional*, defaults to "linear"): + The type of warmup to use. Can be 'linear', 'cosine' or '1-sqrt'. + decay_type (`str`, *optional*, defaults to "cosine"): + The type of decay to use. Can be 'linear', 'cosine' or '1-sqrt'. min_lr_ratio (`float`, *optional*, defaults to 0): The minimum learning rate as a ratio of the initial learning rate. num_cycles (`float`, *optional*, defaults to 0.5): @@ -443,11 +469,29 @@ def get_wsd_schedule( Return: `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. """ + + if num_training_steps is None and num_stable_steps is None: + raise ValueError("Either num_training_steps or num_stable_steps must be specified.") + + if num_training_steps is not None and num_stable_steps is not None: + warnings.warn("Both num_training_steps and num_stable_steps are specified. num_stable_steps will be used.") + + if warmup_type not in ["linear", "cosine", "1-sqrt"]: + raise ValueError(f"Unknown warmup type: {warmup_type}, expected 'linear', 'cosine' or '1-sqrt'") + + if decay_type not in ["linear", "cosine", "1-sqrt"]: + raise ValueError(f"Unknown decay type: {decay_type}, expected 'linear', 'cosine' or '1-sqrt'") + + if num_stable_steps is None: + num_stable_steps = num_training_steps - num_warmup_steps - num_decay_steps + lr_lambda = partial( _get_wsd_scheduler_lambda, num_warmup_steps=num_warmup_steps, num_stable_steps=num_stable_steps, num_decay_steps=num_decay_steps, + warmup_type=warmup_type, + decay_type=decay_type, min_lr_ratio=min_lr_ratio, num_cycles=num_cycles, ) @@ -541,7 +585,12 @@ def get_scheduler( return schedule_func(optimizer, num_warmup_steps=num_warmup_steps) if name == SchedulerType.WARMUP_STABLE_DECAY: - return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, **scheduler_specific_kwargs) + return schedule_func( + optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, + **scheduler_specific_kwargs, + ) # All other schedulers require `num_training_steps` if num_training_steps is None: diff --git a/tests/optimization/test_optimization.py b/tests/optimization/test_optimization.py index 6982583d2b..4ab248e75a 100644 --- a/tests/optimization/test_optimization.py +++ b/tests/optimization/test_optimization.py @@ -153,8 +153,8 @@ class ScheduleInitTest(unittest.TestCase): [0.0, 5.0, 10.0, 8.165, 7.071, 6.325, 5.774, 5.345, 5.0, 4.714], ), get_wsd_schedule: ( - {"num_warmup_steps": 2, "num_stable_steps": 2, "num_decay_steps": 3, "min_lr_ratio": 0.1}, - [0.0, 5.0, 10.0, 10.0, 10.0, 7.75, 3.25, 1.0, 1.0, 1.0], + {**common_kwargs, "num_decay_steps": 2, "min_lr_ratio": 0.0}, + [0.0, 5.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 5.0], ), } @@ -183,14 +183,34 @@ class ScheduleInitTest(unittest.TestCase): "name": "warmup_stable_decay", "optimizer": self.optimizer, "num_warmup_steps": 2, - "scheduler_specific_kwargs": {"num_stable_steps": 1, "num_decay_steps": 3}, + "num_training_steps": 10, + "scheduler_specific_kwargs": { + "num_decay_steps": 2, + "warmup_type": "linear", + "decay_type": "linear", + }, }, { "name": "warmup_stable_decay", "optimizer": self.optimizer, "num_warmup_steps": 2, "num_training_steps": 10, - "scheduler_specific_kwargs": {"num_stable_steps": 1, "num_decay_steps": 3}, + "scheduler_specific_kwargs": { + "num_decay_steps": 2, + "warmup_type": "cosine", + "decay_type": "cosine", + }, + }, + { + "name": "warmup_stable_decay", + "optimizer": self.optimizer, + "num_warmup_steps": 2, + "num_training_steps": 10, + "scheduler_specific_kwargs": { + "num_decay_steps": 2, + "warmup_type": "1-sqrt", + "decay_type": "1-sqrt", + }, }, {"name": "cosine", "optimizer": self.optimizer, "num_warmup_steps": 2, "num_training_steps": 10}, ]