fix optimization_test

This commit is contained in:
thomwolf
2018-11-03 12:23:00 +01:00
parent 45efc9d807
commit 0d8d2285ba
3 changed files with 10 additions and 6 deletions

View File

@@ -24,6 +24,7 @@ import logging
import argparse
import numpy as np
from tqdm import tqdm, trange
import torch
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
@@ -513,8 +514,8 @@ def main():
model.train()
nb_tr_examples = 0
for epoch in range(int(args.num_train_epochs)):
for input_ids, input_mask, segment_ids, label_ids in train_dataloader:
for epoch in trange(args.num_train_epochs, desc="Epoch"):
for input_ids, input_mask, segment_ids, label_ids in tqdm(train_dataloader, desc="Iteration"):
input_ids = input_ids.to(device)
input_mask = input_mask.float().to(device)
segment_ids = segment_ids.to(device)