From ece0903e111b96718a0d845fa358c9581d03f15d Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Tue, 11 Aug 2020 14:56:41 -0700 Subject: [PATCH] lr_schedulers: add get_polynomial_decay_schedule_with_warmup (#6361) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [wip] add get_polynomial_decay_schedule_with_warmup * style * add assert * change lr_end to a much smaller default number * check for exact equality * [model_cards] electra-base-turkish-cased-ner (#6350) * for electra-base-turkish-cased-ner * Add metadata Co-authored-by: Julien Chaumond * Temporarily de-activate TPU CI * Update modeling_tf_utils.py (#6372) fix typo: ckeckpoint->checkpoint * the test now works again (#6371) * correct pl link in readme (#6364) * refactor almost identical tests (#6339) * refactor almost identical tests * important to add a clear assert error message * make the assert error even more descriptive than the original bt * Small docfile fixes (#6328) * Patch models (#6326) * TFAlbertFor{TokenClassification, MultipleChoice} * Patch models * BERT and TF BERT info s * Update check_repo * Ci GitHub caching (#6382) * Cache Github Actions CI * Remove useless file * Colab button (#6389) * Add colab button * Add colab link for tutorials * Fix links for open in colab (#6391) * Update src/transformers/optimization.py consistently use lr_end=1e-7 default Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * [wip] add get_polynomial_decay_schedule_with_warmup * style * add assert * change lr_end to a much smaller default number * check for exact equality * Update src/transformers/optimization.py consistently use lr_end=1e-7 default Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * remove dup (leftover from merge) * convert the test into the new refactored format * stick to using the current_step as is, without ++ Co-authored-by: M. Yusuf Sarıgöz Co-authored-by: Julien Chaumond Co-authored-by: Lysandre Co-authored-by: Alexander Measure Co-authored-by: Rohit Gupta Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Lysandre Debut --- examples/lightning_base.py | 3 ++- src/transformers/__init__.py | 1 + src/transformers/optimization.py | 46 ++++++++++++++++++++++++++++++++ tests/test_optimization.py | 5 ++++ 4 files changed, 54 insertions(+), 1 deletion(-) diff --git a/examples/lightning_base.py b/examples/lightning_base.py index 11a4ded828..bc226f6ba8 100644 --- a/examples/lightning_base.py +++ b/examples/lightning_base.py @@ -25,6 +25,7 @@ from transformers.optimization import ( get_cosine_schedule_with_warmup, get_cosine_with_hard_restarts_schedule_with_warmup, get_linear_schedule_with_warmup, + get_polynomial_decay_schedule_with_warmup, ) @@ -48,7 +49,7 @@ arg_to_scheduler = { "linear": get_linear_schedule_with_warmup, "cosine": get_cosine_schedule_with_warmup, "cosine_w_restarts": get_cosine_with_hard_restarts_schedule_with_warmup, - # polynomial': '', # TODO + "polynomial": get_polynomial_decay_schedule_with_warmup, # '': get_constant_schedule, # not supported for now # '': get_constant_schedule_with_warmup, # not supported for now } diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 9456505c31..317d66baea 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -428,6 +428,7 @@ if is_torch_available(): get_cosine_schedule_with_warmup, get_cosine_with_hard_restarts_schedule_with_warmup, get_linear_schedule_with_warmup, + get_polynomial_decay_schedule_with_warmup, ) # Trainer diff --git a/src/transformers/optimization.py b/src/transformers/optimization.py index 9419dc76c8..410e54f0e8 100644 --- a/src/transformers/optimization.py +++ b/src/transformers/optimization.py @@ -165,6 +165,52 @@ def get_cosine_with_hard_restarts_schedule_with_warmup( return LambdaLR(optimizer, lr_lambda, last_epoch) +def get_polynomial_decay_schedule_with_warmup( + optimizer, num_warmup_steps, num_training_steps, lr_end=1e-7, power=2.0, last_epoch=-1 +): + """ + Create a schedule with a learning rate that decreases as a polynomial decay + from the initial lr set in the optimizer to end lr defined by `lr_end`, + after a warmup period during which it increases linearly from 0 to the + initial lr set in the optimizer. + + Args: + optimizer (:class:`~torch.optim.Optimizer`): + The optimizer for which to schedule the learning rate. + num_warmup_steps (:obj:`int`): + The number of steps for the warmup phase. + num_training_steps (:obj:`int`): + The total number of training steps. + lr_end (:obj:`float`, `optional`, defaults to 1e-7): + The end LR. + power (:obj:`float`, `optional`, defaults to 1.0): + Power factor. + last_epoch (:obj:`int`, `optional`, defaults to -1): + The index of the last epoch when resuming training. + + Return: + :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. + + """ + + lr_init = optimizer.defaults["lr"] + assert lr_init > lr_end, f"lr_end ({lr_end}) must be be smaller than initial lr ({lr_init})" + + def lr_lambda(current_step: int): + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + elif current_step > num_training_steps: + return lr_end / lr_init # as LambdaLR multiplies by lr_init + else: + lr_range = lr_init - lr_end + decay_steps = num_training_steps - num_warmup_steps + pct_remaining = 1 - (current_step - num_warmup_steps) / decay_steps + decay = lr_range * pct_remaining ** power + lr_end + return decay / lr_init # as LambdaLR multiplies by lr_init + + return LambdaLR(optimizer, lr_lambda, last_epoch) + + class AdamW(Optimizer): """ Implements Adam algorithm with weight decay fix as introduced in diff --git a/tests/test_optimization.py b/tests/test_optimization.py index 65687a043e..5ab90dc0f5 100644 --- a/tests/test_optimization.py +++ b/tests/test_optimization.py @@ -32,6 +32,7 @@ if is_torch_available(): get_cosine_schedule_with_warmup, get_cosine_with_hard_restarts_schedule_with_warmup, get_linear_schedule_with_warmup, + get_polynomial_decay_schedule_with_warmup, ) @@ -114,6 +115,10 @@ class ScheduleInitTest(unittest.TestCase): {**common_kwargs, "num_cycles": 2}, [5.0, 10.0, 8.53, 5.0, 1.46, 10.0, 8.53, 5.0, 1.46, 0.0], ), + get_polynomial_decay_schedule_with_warmup: ( + {**common_kwargs, "power": 2.0, "lr_end": 1e-7}, + [5.0, 10.0, 7.656, 5.625, 3.906, 2.5, 1.406, 0.625, 0.156, 1e-07], + ), } for scheduler_func, data in scheds.items():