fixing optimization

This commit is contained in:
thomwolf
2018-11-03 17:38:15 +01:00
parent 852e4b3c00
commit 088ad45888
4 changed files with 85 additions and 49 deletions

View File

@@ -31,13 +31,18 @@ class OptimizationTest(unittest.TestCase):
def test_adam(self):
# Fit the tensor `w` to `target` by minimising an MSE loss with BERTAdam
# for 100 steps, then check that `w` ends up within 1e-2 of the target.
# NOTE(review): this text looks like a rendered diff whose +/- markers and
# indentation were stripped, so superseded and replacement lines appear
# back to back (x/target, the two optimizer constructions, the two loss
# lines) -- confirm against the real test file before running it.
w = torch.tensor([0.1, -0.2, -0.1], requires_grad=True)
# Presumably the pre-change name; `target` below holds the same values.
x = torch.tensor([0.4, 0.2, -0.5])
target = torch.tensor([0.4, 0.2, -0.5])
criterion = torch.nn.MSELoss(reduction='elementwise_mean')
# Presumably the pre-change constructor (set of params, warmup_linear
# schedule); `optimizer` is rebound by the call that follows.
optimizer = optimization.BERTAdam(params={w}, lr=0.2, schedule='warmup_linear', warmup=0.1, t_total=100)
# No warmup, constant schedule, no gradient clipping
optimizer = optimization.BERTAdam(params=[w], lr=2e-1,
weight_decay_rate=0.0,
max_grad_norm=-1)
for _ in range(100):
# Presumably the pre-change loss; `loss` is rebound on the next line.
loss = criterion(w, x)
loss = criterion(w, target)
# Backpropagate the loss and apply one optimizer update to `w`.
loss.backward()
optimizer.step()
w.grad.detach_() # No zero_grad() function on simple tensors. we do it ourselves.
w.grad.zero_()
# `assertListAlmostEqual` is defined outside this view; presumably it
# compares the two lists element-wise within `tol` -- TODO confirm.
self.assertListAlmostEqual(w.tolist(), [0.4, 0.2, -0.5], tol=1e-2)