replace LambdaLR scheduler wrappers by function

Custom schedulers are currently initiated by wrapping Pytorch's LambdaLR
class and passing a method of the wrapping class to the __init__
function of LambdaLR. This approach is not appropriate for several
reasons:

1. one does not need to define a class when it only defines a
__init__() method;
2. instantiating the parent class by passing a method of the child class
creates a cyclical reference which leads to memory leaks. See issues #1742 and #1134.

In this commit we replace the wrapper classes with functions that
instantiate `LambdaLR` with a custom learning rate function. We use a
closure to specify the parameter of the latter. We also do a bit of
renaming within the function to explicit the behaviour and removed
docstrings that were subsequently not necessary.
This commit is contained in:
Rémi Louf
2019-11-12 11:08:47 +01:00
parent 1c542df7e5
commit 022525b003
3 changed files with 61 additions and 80 deletions

View File

@@ -25,8 +25,12 @@ from transformers import is_torch_available
if is_torch_available():
import torch
from transformers import (AdamW, ConstantLRSchedule, WarmupConstantSchedule,
WarmupCosineSchedule, WarmupCosineWithHardRestartsSchedule, WarmupLinearSchedule)
from transformers import (AdamW,
get_constant_schedule,
get_constant_schedule_with_warmup,
get_cosine_schedule_with_warmup,
get_cosine_with_hard_restarts_schedule_with_warmup,
get_linear_schedule_with_warmup)
else:
pytestmark = pytest.mark.skip("Require Torch")
@@ -87,59 +91,60 @@ class ScheduleInitTest(unittest.TestCase):
self.assertAlmostEqual(a, b, delta=tol)
def test_constant_scheduler(self):
scheduler = ConstantLRSchedule(self.optimizer)
scheduler = get_constant_schedule(self.optimizer)
lrs = unwrap_schedule(scheduler, self.num_steps)
expected_learning_rates = [10.] * self.num_steps
self.assertEqual(len(lrs[0]), 1)
self.assertListEqual([l[0] for l in lrs], expected_learning_rates)
scheduler = ConstantLRSchedule(self.optimizer)
scheduler = get_constant_schedule(self.optimizer)
lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps)
self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2])
def test_warmup_constant_scheduler(self):
scheduler = WarmupConstantSchedule(self.optimizer, warmup_steps=4)
scheduler = get_constant_schedule_with_warmup(self.optimizer, num_warmup_steps=4)
lrs = unwrap_schedule(scheduler, self.num_steps)
expected_learning_rates = [2.5, 5.0, 7.5, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0]
self.assertEqual(len(lrs[0]), 1)
self.assertListEqual([l[0] for l in lrs], expected_learning_rates)
scheduler = WarmupConstantSchedule(self.optimizer, warmup_steps=4)
scheduler = get_constant_schedule_with_warmup(self.optimizer, num_warmup_steps=4)
lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps)
self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2])
def test_warmup_linear_scheduler(self):
scheduler = WarmupLinearSchedule(self.optimizer, warmup_steps=2, t_total=10)
scheduler = get_linear_schedule_with_warmup(self.optimizer, num_warmup_steps=2, num_training_steps=10)
lrs = unwrap_schedule(scheduler, self.num_steps)
expected_learning_rates = [5.0, 10.0, 8.75, 7.5, 6.25, 5.0, 3.75, 2.5, 1.25, 0.0]
self.assertEqual(len(lrs[0]), 1)
self.assertListEqual([l[0] for l in lrs], expected_learning_rates)
scheduler = WarmupLinearSchedule(self.optimizer, warmup_steps=2, t_total=10)
scheduler = get_linear_schedule_with_warmup(self.optimizer, num_warmup_steps=2, num_training_steps=10)
lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps)
self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2])
def test_warmup_cosine_scheduler(self):
scheduler = WarmupCosineSchedule(self.optimizer, warmup_steps=2, t_total=10)
scheduler = get_cosine_schedule_with_warmup(self.optimizer, num_warmup_steps=2, num_training_steps=10)
lrs = unwrap_schedule(scheduler, self.num_steps)
expected_learning_rates = [5.0, 10.0, 9.61, 8.53, 6.91, 5.0, 3.08, 1.46, 0.38, 0.0]
self.assertEqual(len(lrs[0]), 1)
self.assertListAlmostEqual([l[0] for l in lrs], expected_learning_rates, tol=1e-2)
scheduler = WarmupCosineSchedule(self.optimizer, warmup_steps=2, t_total=10)
scheduler = get_cosine_schedule_with_warmup(self.optimizer, num_warmup_steps=2, num_training_steps=10)
lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps)
self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2])
def test_warmup_cosine_hard_restart_scheduler(self):
scheduler = WarmupCosineWithHardRestartsSchedule(self.optimizer, warmup_steps=2, cycles=2, t_total=10)
scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(self.optimizer, num_warmup_steps=2, num_cycles=2, num_training_steps=10)
lrs = unwrap_schedule(scheduler, self.num_steps)
expected_learning_rates = [5.0, 10.0, 8.53, 5.0, 1.46, 10.0, 8.53, 5.0, 1.46, 0.0]
self.assertEqual(len(lrs[0]), 1)
self.assertListAlmostEqual([l[0] for l in lrs], expected_learning_rates, tol=1e-2)
scheduler = WarmupCosineWithHardRestartsSchedule(self.optimizer, warmup_steps=2, cycles=2, t_total=10)
scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(self.optimizer, num_warmup_steps=2, num_cycles=2, num_training_steps=10)
lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps)
self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2])
if __name__ == "__main__":
unittest.main()