Make schedulers picklable by making lr_lambda fns global (#21768)
* Make schedulers picklable by making lr_lambda fns global * add unused _get_constant_schedule_lr_lambda arg * remove unneeded _get_constant_schedule_lr_lamda * add test * make style * rebase, remove torch dep, put lambda back * repo-consistency and style
This commit is contained in:
@@ -166,5 +166,21 @@ class ScheduleInitTest(unittest.TestCase):
|
||||
)
|
||||
|
||||
scheduler = scheduler_func(self.optimizer, **kwargs)
|
||||
if scheduler_func.__name__ != "get_constant_schedule":
|
||||
LambdaScheduleWrapper.wrap_scheduler(scheduler) # wrap to test picklability of the schedule
|
||||
lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps)
|
||||
self.assertListEqual(lrs_1, lrs_2, msg=f"failed for {scheduler_func} in save and reload")
|
||||
|
||||
|
||||
class LambdaScheduleWrapper:
|
||||
"""See https://github.com/huggingface/transformers/issues/21689"""
|
||||
|
||||
def __init__(self, fn):
|
||||
self.fn = fn
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self.fn(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def wrap_scheduler(self, scheduler):
|
||||
scheduler.lr_lambdas = list(map(self, scheduler.lr_lambdas))
|
||||
|
||||
Reference in New Issue
Block a user