diff --git a/tests/optimization_test.py b/tests/optimization_test.py index ad13c28d0c..848b9d1cf5 100644 --- a/tests/optimization_test.py +++ b/tests/optimization_test.py @@ -32,7 +32,7 @@ class OptimizationTest(unittest.TestCase): def test_adam(self): w = torch.tensor([0.1, -0.2, -0.1], requires_grad=True) target = torch.tensor([0.4, 0.2, -0.5]) - criterion = torch.nn.MSELoss(reduction='elementwise_mean') + criterion = torch.nn.MSELoss() # No warmup, constant schedule, no gradient clipping optimizer = BertAdam(params=[w], lr=2e-1, weight_decay=0.0,