fix optimization_test
This commit is contained in:
@@ -24,6 +24,7 @@ import logging
|
||||
import argparse
|
||||
|
||||
import numpy as np
|
||||
from tqdm import tqdm, trange
|
||||
import torch
|
||||
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
@@ -513,8 +514,8 @@ def main():
|
||||
|
||||
model.train()
|
||||
nb_tr_examples = 0
|
||||
for epoch in range(int(args.num_train_epochs)):
|
||||
for input_ids, input_mask, segment_ids, label_ids in train_dataloader:
|
||||
for epoch in trange(args.num_train_epochs, desc="Epoch"):
|
||||
for input_ids, input_mask, segment_ids, label_ids in tqdm(train_dataloader, desc="Iteration"):
|
||||
input_ids = input_ids.to(device)
|
||||
input_mask = input_mask.float().to(device)
|
||||
segment_ids = segment_ids.to(device)
|
||||
|
||||
Reference in New Issue
Block a user