From 0d8d2285bae1ea021516b07f6878a6b0fb8eeac0 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Sat, 3 Nov 2018 12:23:00 +0100 Subject: [PATCH] fix optimization_test --- optimization_test_pytorch.py | 8 ++++---- requirements.txt | 3 +++ run_classifier_pytorch.py | 5 +++-- 3 files changed, 10 insertions(+), 6 deletions(-) create mode 100644 requirements.txt diff --git a/optimization_test_pytorch.py b/optimization_test_pytorch.py index 5021467d1f..4d6e40352b 100644 --- a/optimization_test_pytorch.py +++ b/optimization_test_pytorch.py @@ -16,10 +16,11 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import optimization_pytorch as optimization -import torch import unittest +import torch + +import optimization_pytorch as optimization class OptimizationTest(unittest.TestCase): @@ -34,8 +35,7 @@ class OptimizationTest(unittest.TestCase): criterion = torch.nn.MSELoss(reduction='elementwise_mean') optimizer = optimization.BERTAdam(params={w}, lr=0.2, schedule='warmup_linear', warmup=0.1, t_total=100) for _ in range(100): - # TODO Solve: reduction='elementwise_mean'=True not taken into account so division by x.size(0) is necessary - loss = criterion(x, w) / x.size(0) + loss = criterion(w, x) loss.backward() optimizer.step() self.assertListAlmostEqual(w.tolist(), [0.4, 0.2, -0.5], tol=1e-2) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000..b6070041b6 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,3 @@ +torch +tqdm +pytest \ No newline at end of file diff --git a/run_classifier_pytorch.py b/run_classifier_pytorch.py index d7b4e2572b..426bff64c9 100644 --- a/run_classifier_pytorch.py +++ b/run_classifier_pytorch.py @@ -24,6 +24,7 @@ import logging import argparse import numpy as np +from tqdm import tqdm, trange import torch from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler from torch.utils.data.distributed import DistributedSampler @@ -513,8 +514,8 @@ def main(): model.train() nb_tr_examples = 0 - for epoch in range(int(args.num_train_epochs)): - for input_ids, input_mask, segment_ids, label_ids in train_dataloader: + for epoch in trange(args.num_train_epochs, desc="Epoch"): + for input_ids, input_mask, segment_ids, label_ids in tqdm(train_dataloader, desc="Iteration"): input_ids = input_ids.to(device) input_mask = input_mask.float().to(device) segment_ids = segment_ids.to(device)