updating schedules for state_dict saving
This commit is contained in:
@@ -17,13 +17,14 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import os
|
||||
|
||||
import torch
|
||||
|
||||
from pytorch_transformers import (AdamW, ConstantLRSchedule, WarmupConstantSchedule,
|
||||
WarmupCosineSchedule, WarmupCosineWithHardRestartsSchedule, WarmupLinearSchedule)
|
||||
|
||||
import numpy as np
|
||||
from .tokenization_tests_commons import TemporaryDirectory
|
||||
|
||||
|
||||
def unwrap_schedule(scheduler, num_steps=10):
|
||||
@@ -33,6 +34,20 @@ def unwrap_schedule(scheduler, num_steps=10):
|
||||
lrs.append(scheduler.get_lr())
|
||||
return lrs
|
||||
|
||||
def unwrap_and_save_reload_schedule(scheduler, num_steps=10):
|
||||
lrs = []
|
||||
for step in range(num_steps):
|
||||
scheduler.step()
|
||||
lrs.append(scheduler.get_lr())
|
||||
if step == num_steps // 2:
|
||||
with TemporaryDirectory() as tmpdirname:
|
||||
file_name = os.path.join(tmpdirname, 'schedule.bin')
|
||||
torch.save(scheduler.state_dict(), file_name)
|
||||
|
||||
state_dict = torch.load(file_name)
|
||||
scheduler.load_state_dict(state_dict)
|
||||
return lrs
|
||||
|
||||
class OptimizationTest(unittest.TestCase):
|
||||
|
||||
def assertListAlmostEqual(self, list1, list2, tol):
|
||||
@@ -72,6 +87,10 @@ class ScheduleInitTest(unittest.TestCase):
|
||||
self.assertEqual(len(lrs[0]), 1)
|
||||
self.assertListEqual([l[0] for l in lrs], expected_learning_rates)
|
||||
|
||||
scheduler = ConstantLRSchedule(self.optimizer)
|
||||
lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps)
|
||||
self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2])
|
||||
|
||||
def test_warmup_constant_scheduler(self):
|
||||
scheduler = WarmupConstantSchedule(self.optimizer, warmup_steps=4)
|
||||
lrs = unwrap_schedule(scheduler, self.num_steps)
|
||||
@@ -79,6 +98,10 @@ class ScheduleInitTest(unittest.TestCase):
|
||||
self.assertEqual(len(lrs[0]), 1)
|
||||
self.assertListEqual([l[0] for l in lrs], expected_learning_rates)
|
||||
|
||||
scheduler = WarmupConstantSchedule(self.optimizer, warmup_steps=4)
|
||||
lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps)
|
||||
self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2])
|
||||
|
||||
def test_warmup_linear_scheduler(self):
|
||||
scheduler = WarmupLinearSchedule(self.optimizer, warmup_steps=2, t_total=10)
|
||||
lrs = unwrap_schedule(scheduler, self.num_steps)
|
||||
@@ -86,6 +109,10 @@ class ScheduleInitTest(unittest.TestCase):
|
||||
self.assertEqual(len(lrs[0]), 1)
|
||||
self.assertListEqual([l[0] for l in lrs], expected_learning_rates)
|
||||
|
||||
scheduler = WarmupLinearSchedule(self.optimizer, warmup_steps=2, t_total=10)
|
||||
lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps)
|
||||
self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2])
|
||||
|
||||
def test_warmup_cosine_scheduler(self):
|
||||
scheduler = WarmupCosineSchedule(self.optimizer, warmup_steps=2, t_total=10)
|
||||
lrs = unwrap_schedule(scheduler, self.num_steps)
|
||||
@@ -93,6 +120,10 @@ class ScheduleInitTest(unittest.TestCase):
|
||||
self.assertEqual(len(lrs[0]), 1)
|
||||
self.assertListAlmostEqual([l[0] for l in lrs], expected_learning_rates, tol=1e-2)
|
||||
|
||||
scheduler = WarmupCosineSchedule(self.optimizer, warmup_steps=2, t_total=10)
|
||||
lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps)
|
||||
self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2])
|
||||
|
||||
def test_warmup_cosine_hard_restart_scheduler(self):
|
||||
scheduler = WarmupCosineWithHardRestartsSchedule(self.optimizer, warmup_steps=2, cycles=2, t_total=10)
|
||||
lrs = unwrap_schedule(scheduler, self.num_steps)
|
||||
@@ -100,6 +131,9 @@ class ScheduleInitTest(unittest.TestCase):
|
||||
self.assertEqual(len(lrs[0]), 1)
|
||||
self.assertListAlmostEqual([l[0] for l in lrs], expected_learning_rates, tol=1e-2)
|
||||
|
||||
scheduler = WarmupCosineWithHardRestartsSchedule(self.optimizer, warmup_steps=2, cycles=2, t_total=10)
|
||||
lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps)
|
||||
self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2])
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user