update run_squad with tqdm
This commit is contained in:
@@ -18,19 +18,20 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import six
|
||||
import argparse
|
||||
import collections
|
||||
import logging
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import tokenization_pytorch
|
||||
import six
|
||||
import argparse
|
||||
from tqdm import tqdm, trange
|
||||
|
||||
import torch
|
||||
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
|
||||
import tokenization_pytorch
|
||||
from modeling_pytorch import BertConfig, BertForQuestionAnswering
|
||||
from optimization_pytorch import BERTAdam
|
||||
|
||||
@@ -841,9 +842,9 @@ def main():
|
||||
|
||||
logger.info("HHHHH Starting Traing")
|
||||
model.train()
|
||||
for epoch in range(int(args.num_train_epochs)):
|
||||
#for input_ids, input_mask, segment_ids, label_ids in train_dataloader:
|
||||
for input_ids, input_mask, segment_ids, start_positions, end_positions in train_dataloader:
|
||||
for epoch in trange(int(args.num_train_epochs), desc="Epoch"):
|
||||
for input_ids, input_mask, segment_ids, start_positions, end_positions in tqdm(train_dataloader,
|
||||
desc="Iteration"):
|
||||
input_ids = input_ids.to(device)
|
||||
input_mask = input_mask.float().to(device)
|
||||
segment_ids = segment_ids.to(device)
|
||||
|
||||
Reference in New Issue
Block a user