rewamp optimization
This commit is contained in:
@@ -20,10 +20,9 @@ import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from pytorch_transformers import BertAdam
|
||||
from pytorch_transformers import OpenAIAdam
|
||||
from pytorch_transformers.optimization import ConstantLR, WarmupLinearSchedule, WarmupConstantSchedule, \
|
||||
WarmupCosineWithWarmupRestartsSchedule, WarmupCosineWithHardRestartsSchedule, WarmupCosineSchedule
|
||||
from pytorch_transformers import (AdamW, ConstantLRSchedule, WarmupConstantSchedule,
|
||||
WarmupCosineSchedule, WarmupCosineWithHardRestartsSchedule, WarmupLinearSchedule)
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
@@ -34,12 +33,12 @@ class OptimizationTest(unittest.TestCase):
|
||||
for a, b in zip(list1, list2):
|
||||
self.assertAlmostEqual(a, b, delta=tol)
|
||||
|
||||
def test_adam(self):
|
||||
def test_adam_w(self):
|
||||
w = torch.tensor([0.1, -0.2, -0.1], requires_grad=True)
|
||||
target = torch.tensor([0.4, 0.2, -0.5])
|
||||
criterion = torch.nn.MSELoss()
|
||||
# No warmup, constant schedule, no gradient clipping
|
||||
optimizer = BertAdam(params=[w], lr=2e-1,
|
||||
optimizer = AdamW(params=[w], lr=2e-1,
|
||||
weight_decay=0.0,
|
||||
max_grad_norm=-1)
|
||||
for _ in range(100):
|
||||
@@ -52,23 +51,13 @@ class OptimizationTest(unittest.TestCase):
|
||||
|
||||
|
||||
class ScheduleInitTest(unittest.TestCase):
|
||||
def test_bert_sched_init(self):
|
||||
def test_sched_init(self):
|
||||
m = torch.nn.Linear(50, 50)
|
||||
optim = BertAdam(m.parameters(), lr=0.001, warmup=.1, t_total=1000, schedule=None)
|
||||
optim = AdamW(m.parameters(), lr=0.001, warmup=.1, t_total=1000, schedule=None)
|
||||
self.assertTrue(isinstance(optim.param_groups[0]["schedule"], ConstantLR))
|
||||
optim = BertAdam(m.parameters(), lr=0.001, warmup=.1, t_total=1000, schedule="none")
|
||||
optim = AdamW(m.parameters(), lr=0.001, warmup=.1, t_total=1000, schedule="none")
|
||||
self.assertTrue(isinstance(optim.param_groups[0]["schedule"], ConstantLR))
|
||||
optim = BertAdam(m.parameters(), lr=0.001, warmup=.01, t_total=1000)
|
||||
self.assertTrue(isinstance(optim.param_groups[0]["schedule"], WarmupLinearSchedule))
|
||||
# shouldn't fail
|
||||
|
||||
def test_openai_sched_init(self):
|
||||
m = torch.nn.Linear(50, 50)
|
||||
optim = OpenAIAdam(m.parameters(), lr=0.001, warmup=.1, t_total=1000, schedule=None)
|
||||
self.assertTrue(isinstance(optim.param_groups[0]["schedule"], ConstantLR))
|
||||
optim = OpenAIAdam(m.parameters(), lr=0.001, warmup=.1, t_total=1000, schedule="none")
|
||||
self.assertTrue(isinstance(optim.param_groups[0]["schedule"], ConstantLR))
|
||||
optim = OpenAIAdam(m.parameters(), lr=0.001, warmup=.01, t_total=1000)
|
||||
optim = AdamW(m.parameters(), lr=0.001, warmup=.01, t_total=1000)
|
||||
self.assertTrue(isinstance(optim.param_groups[0]["schedule"], WarmupLinearSchedule))
|
||||
# shouldn't fail
|
||||
|
||||
|
||||
Reference in New Issue
Block a user