fix optimization_test

This commit is contained in:
thomwolf
2018-11-03 12:23:00 +01:00
parent 45efc9d807
commit 0d8d2285ba
3 changed files with 10 additions and 6 deletions

View File

@@ -16,10 +16,11 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import optimization_pytorch as optimization
import torch
import unittest import unittest
import torch
import optimization_pytorch as optimization
class OptimizationTest(unittest.TestCase): class OptimizationTest(unittest.TestCase):
@@ -34,8 +35,7 @@ class OptimizationTest(unittest.TestCase):
criterion = torch.nn.MSELoss(reduction='elementwise_mean') criterion = torch.nn.MSELoss(reduction='elementwise_mean')
optimizer = optimization.BERTAdam(params={w}, lr=0.2, schedule='warmup_linear', warmup=0.1, t_total=100) optimizer = optimization.BERTAdam(params={w}, lr=0.2, schedule='warmup_linear', warmup=0.1, t_total=100)
for _ in range(100): for _ in range(100):
# TODO Solve: reduction='elementwise_mean'=True not taken into account so division by x.size(0) is necessary loss = criterion(w, x)
loss = criterion(x, w) / x.size(0)
loss.backward() loss.backward()
optimizer.step() optimizer.step()
self.assertListAlmostEqual(w.tolist(), [0.4, 0.2, -0.5], tol=1e-2) self.assertListAlmostEqual(w.tolist(), [0.4, 0.2, -0.5], tol=1e-2)

3
requirements.txt Normal file
View File

@@ -0,0 +1,3 @@
torch
tqdm
pytest

View File

@@ -24,6 +24,7 @@ import logging
import argparse import argparse
import numpy as np import numpy as np
from tqdm import tqdm, trange
import torch import torch
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler from torch.utils.data.distributed import DistributedSampler
@@ -513,8 +514,8 @@ def main():
model.train() model.train()
nb_tr_examples = 0 nb_tr_examples = 0
for epoch in range(int(args.num_train_epochs)): for epoch in trange(args.num_train_epochs, desc="Epoch"):
for input_ids, input_mask, segment_ids, label_ids in train_dataloader: for input_ids, input_mask, segment_ids, label_ids in tqdm(train_dataloader, desc="Iteration"):
input_ids = input_ids.to(device) input_ids = input_ids.to(device)
input_mask = input_mask.float().to(device) input_mask = input_mask.float().to(device)
segment_ids = segment_ids.to(device) segment_ids = segment_ids.to(device)