Make schedulers picklable by making lr_lambda fns global (#21768)

* Make schedulers picklable by making lr_lambda fns global

* add unused _get_constant_schedule_lr_lambda arg

* remove unneeded _get_constant_schedule_lr_lamda

* add test

* make style

* rebase, remove torch dep, put lambda back

* repo-consistency and style
This commit is contained in:
Connor Henderson
2023-03-02 12:08:43 -05:00
committed by GitHub
parent 6bf885375a
commit 8e5a1b2abb
2 changed files with 106 additions and 45 deletions

View File

@@ -166,5 +166,21 @@ class ScheduleInitTest(unittest.TestCase):
)
scheduler = scheduler_func(self.optimizer, **kwargs)
if scheduler_func.__name__ != "get_constant_schedule":
LambdaScheduleWrapper.wrap_scheduler(scheduler) # wrap to test picklability of the schedule
lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps)
self.assertListEqual(lrs_1, lrs_2, msg=f"failed for {scheduler_func} in save and reload")
class LambdaScheduleWrapper:
"""See https://github.com/huggingface/transformers/issues/21689"""
def __init__(self, fn):
self.fn = fn
def __call__(self, *args, **kwargs):
return self.fn(*args, **kwargs)
@classmethod
def wrap_scheduler(self, scheduler):
scheduler.lr_lambdas = list(map(self, scheduler.lr_lambdas))