fix tests
This commit is contained in:
@@ -20,7 +20,7 @@ import unittest
|
||||
|
||||
import torch
|
||||
|
||||
import optimization
|
||||
from pytorch_pretrained_bert import BertAdam
|
||||
|
||||
class OptimizationTest(unittest.TestCase):
|
||||
|
||||
@@ -34,7 +34,7 @@ class OptimizationTest(unittest.TestCase):
|
||||
target = torch.tensor([0.4, 0.2, -0.5])
|
||||
criterion = torch.nn.MSELoss(reduction='elementwise_mean')
|
||||
# No warmup, constant schedule, no gradient clipping
|
||||
optimizer = optimization.BERTAdam(params=[w], lr=2e-1,
|
||||
optimizer = BertAdam(params=[w], lr=2e-1,
|
||||
weight_decay_rate=0.0,
|
||||
max_grad_norm=-1)
|
||||
for _ in range(100):
|
||||
|
||||
Reference in New Issue
Block a user