fix tests

This commit is contained in:
thomwolf
2018-11-17 11:58:14 +01:00
parent a99b971738
commit 757750d6f6
9 changed files with 48 additions and 61 deletions

View File

@@ -20,7 +20,7 @@ import unittest
import torch
import optimization
from pytorch_pretrained_bert import BertAdam
class OptimizationTest(unittest.TestCase):
@@ -34,7 +34,7 @@ class OptimizationTest(unittest.TestCase):
target = torch.tensor([0.4, 0.2, -0.5])
criterion = torch.nn.MSELoss(reduction='elementwise_mean')
# No warmup, constant schedule, no gradient clipping
optimizer = optimization.BERTAdam(params=[w], lr=2e-1,
optimizer = BertAdam(params=[w], lr=2e-1,
weight_decay_rate=0.0,
max_grad_norm=-1)
for _ in range(100):