diff --git a/examples/lightning_base.py b/examples/lightning_base.py
index 11a4ded828..bc226f6ba8 100644
--- a/examples/lightning_base.py
+++ b/examples/lightning_base.py
@@ -25,6 +25,7 @@ from transformers.optimization import (
     get_cosine_schedule_with_warmup,
     get_cosine_with_hard_restarts_schedule_with_warmup,
     get_linear_schedule_with_warmup,
+    get_polynomial_decay_schedule_with_warmup,
 )
 
 
@@ -48,7 +49,7 @@ arg_to_scheduler = {
     "linear": get_linear_schedule_with_warmup,
     "cosine": get_cosine_schedule_with_warmup,
     "cosine_w_restarts": get_cosine_with_hard_restarts_schedule_with_warmup,
-    # polynomial': '', # TODO
+    "polynomial": get_polynomial_decay_schedule_with_warmup,
     # '': get_constant_schedule, # not supported for now
     # '': get_constant_schedule_with_warmup, # not supported for now
 }
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index 9456505c31..317d66baea 100755
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -428,6 +428,7 @@ if is_torch_available():
         get_cosine_schedule_with_warmup,
         get_cosine_with_hard_restarts_schedule_with_warmup,
         get_linear_schedule_with_warmup,
+        get_polynomial_decay_schedule_with_warmup,
     )
 
     # Trainer
diff --git a/src/transformers/optimization.py b/src/transformers/optimization.py
index 9419dc76c8..410e54f0e8 100644
--- a/src/transformers/optimization.py
+++ b/src/transformers/optimization.py
@@ -165,6 +165,52 @@ def get_cosine_with_hard_restarts_schedule_with_warmup(
     return LambdaLR(optimizer, lr_lambda, last_epoch)
 
 
+def get_polynomial_decay_schedule_with_warmup(
+    optimizer, num_warmup_steps, num_training_steps, lr_end=1e-7, power=2.0, last_epoch=-1
+):
+    """
+    Create a schedule with a learning rate that decreases as a polynomial decay
+    from the initial lr set in the optimizer to end lr defined by `lr_end`,
+    after a warmup period during which it increases linearly from 0 to the
+    initial lr set in the optimizer.
+
+    Args:
+        optimizer (:class:`~torch.optim.Optimizer`):
+            The optimizer for which to schedule the learning rate.
+        num_warmup_steps (:obj:`int`):
+            The number of steps for the warmup phase.
+        num_training_steps (:obj:`int`):
+            The total number of training steps.
+        lr_end (:obj:`float`, `optional`, defaults to 1e-7):
+            The end LR.
+        power (:obj:`float`, `optional`, defaults to 2.0):
+            Power factor.
+        last_epoch (:obj:`int`, `optional`, defaults to -1):
+            The index of the last epoch when resuming training.
+
+    Return:
+        :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
+
+    """
+
+    lr_init = optimizer.defaults["lr"]
+    assert lr_init > lr_end, f"lr_end ({lr_end}) must be smaller than initial lr ({lr_init})"
+
+    def lr_lambda(current_step: int):
+        if current_step < num_warmup_steps:
+            return float(current_step) / float(max(1, num_warmup_steps))
+        elif current_step > num_training_steps:
+            return lr_end / lr_init  # as LambdaLR multiplies by lr_init
+        else:
+            lr_range = lr_init - lr_end
+            decay_steps = num_training_steps - num_warmup_steps
+            pct_remaining = 1 - (current_step - num_warmup_steps) / decay_steps
+            decay = lr_range * pct_remaining ** power + lr_end
+            return decay / lr_init  # as LambdaLR multiplies by lr_init
+
+    return LambdaLR(optimizer, lr_lambda, last_epoch)
+
+
 class AdamW(Optimizer):
     """
     Implements Adam algorithm with weight decay fix as introduced in
diff --git a/tests/test_optimization.py b/tests/test_optimization.py
index 65687a043e..5ab90dc0f5 100644
--- a/tests/test_optimization.py
+++ b/tests/test_optimization.py
@@ -32,6 +32,7 @@ if is_torch_available():
         get_cosine_schedule_with_warmup,
         get_cosine_with_hard_restarts_schedule_with_warmup,
         get_linear_schedule_with_warmup,
+        get_polynomial_decay_schedule_with_warmup,
     )
 
 
@@ -114,6 +115,10 @@ class ScheduleInitTest(unittest.TestCase):
                 {**common_kwargs, "num_cycles": 2},
                 [5.0, 10.0, 8.53, 5.0, 1.46, 10.0, 8.53, 5.0, 1.46, 0.0],
             ),
+            get_polynomial_decay_schedule_with_warmup: (
+                {**common_kwargs, "power": 2.0, "lr_end": 1e-7},
+                [5.0, 10.0, 7.656, 5.625, 3.906, 2.5, 1.406, 0.625, 0.156, 1e-07],
+            ),
         }
 
         for scheduler_func, data in scheds.items():