fix optimization_test
This commit is contained in:
@@ -16,10 +16,11 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import optimization_pytorch as optimization
|
||||
import torch
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
import optimization_pytorch as optimization
|
||||
|
||||
class OptimizationTest(unittest.TestCase):
|
||||
|
||||
@@ -34,8 +35,7 @@ class OptimizationTest(unittest.TestCase):
|
||||
criterion = torch.nn.MSELoss(reduction='elementwise_mean')
|
||||
optimizer = optimization.BERTAdam(params={w}, lr=0.2, schedule='warmup_linear', warmup=0.1, t_total=100)
|
||||
for _ in range(100):
|
||||
# TODO Solve: reduction='elementwise_mean'=True not taken into account so division by x.size(0) is necessary
|
||||
loss = criterion(x, w) / x.size(0)
|
||||
loss = criterion(w, x)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
self.assertListAlmostEqual(w.tolist(), [0.4, 0.2, -0.5], tol=1e-2)
|
||||
|
||||
Reference in New Issue
Block a user