fix optimization_test
This commit is contained in:
@@ -16,10 +16,11 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import optimization_pytorch as optimization
|
|
||||||
import torch
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
import optimization_pytorch as optimization
|
||||||
|
|
||||||
class OptimizationTest(unittest.TestCase):
|
class OptimizationTest(unittest.TestCase):
|
||||||
|
|
||||||
@@ -34,8 +35,7 @@ class OptimizationTest(unittest.TestCase):
|
|||||||
criterion = torch.nn.MSELoss(reduction='elementwise_mean')
|
criterion = torch.nn.MSELoss(reduction='elementwise_mean')
|
||||||
optimizer = optimization.BERTAdam(params={w}, lr=0.2, schedule='warmup_linear', warmup=0.1, t_total=100)
|
optimizer = optimization.BERTAdam(params={w}, lr=0.2, schedule='warmup_linear', warmup=0.1, t_total=100)
|
||||||
for _ in range(100):
|
for _ in range(100):
|
||||||
# TODO Solve: reduction='elementwise_mean'=True not taken into account so division by x.size(0) is necessary
|
loss = criterion(w, x)
|
||||||
loss = criterion(x, w) / x.size(0)
|
|
||||||
loss.backward()
|
loss.backward()
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
self.assertListAlmostEqual(w.tolist(), [0.4, 0.2, -0.5], tol=1e-2)
|
self.assertListAlmostEqual(w.tolist(), [0.4, 0.2, -0.5], tol=1e-2)
|
||||||
|
|||||||
3
requirements.txt
Normal file
3
requirements.txt
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
torch
|
||||||
|
tqdm
|
||||||
|
pytest
|
||||||
@@ -24,6 +24,7 @@ import logging
|
|||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from tqdm import tqdm, trange
|
||||||
import torch
|
import torch
|
||||||
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
|
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
|
||||||
from torch.utils.data.distributed import DistributedSampler
|
from torch.utils.data.distributed import DistributedSampler
|
||||||
@@ -513,8 +514,8 @@ def main():
|
|||||||
|
|
||||||
model.train()
|
model.train()
|
||||||
nb_tr_examples = 0
|
nb_tr_examples = 0
|
||||||
for epoch in range(int(args.num_train_epochs)):
|
for epoch in trange(args.num_train_epochs, desc="Epoch"):
|
||||||
for input_ids, input_mask, segment_ids, label_ids in train_dataloader:
|
for input_ids, input_mask, segment_ids, label_ids in tqdm(train_dataloader, desc="Iteration"):
|
||||||
input_ids = input_ids.to(device)
|
input_ids = input_ids.to(device)
|
||||||
input_mask = input_mask.float().to(device)
|
input_mask = input_mask.float().to(device)
|
||||||
segment_ids = segment_ids.to(device)
|
segment_ids = segment_ids.to(device)
|
||||||
|
|||||||
Reference in New Issue
Block a user