change to apex for better fp16 and multi-gpu support

This commit is contained in:
Deyu Fu
2018-12-05 15:07:40 -08:00
parent a3a3180c86
commit c8ea286048
6 changed files with 142 additions and 169 deletions

View File

@@ -35,7 +35,7 @@ class OptimizationTest(unittest.TestCase):
criterion = torch.nn.MSELoss(reduction='elementwise_mean')
# No warmup, constant schedule, no gradient clipping
optimizer = BertAdam(params=[w], lr=2e-1,
weight_decay_rate=0.0,
weight_decay=0.0,
max_grad_norm=-1)
for _ in range(100):
loss = criterion(w, target)