update run_squad with tqdm
This commit is contained in:
@@ -18,19 +18,20 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import six
|
||||||
|
import argparse
|
||||||
import collections
|
import collections
|
||||||
import logging
|
import logging
|
||||||
import json
|
import json
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import tokenization_pytorch
|
from tqdm import tqdm, trange
|
||||||
import six
|
|
||||||
import argparse
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
|
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
|
||||||
from torch.utils.data.distributed import DistributedSampler
|
from torch.utils.data.distributed import DistributedSampler
|
||||||
|
|
||||||
|
import tokenization_pytorch
|
||||||
from modeling_pytorch import BertConfig, BertForQuestionAnswering
|
from modeling_pytorch import BertConfig, BertForQuestionAnswering
|
||||||
from optimization_pytorch import BERTAdam
|
from optimization_pytorch import BERTAdam
|
||||||
|
|
||||||
@@ -841,9 +842,9 @@ def main():
|
|||||||
|
|
||||||
logger.info("HHHHH Starting Traing")
|
logger.info("HHHHH Starting Traing")
|
||||||
model.train()
|
model.train()
|
||||||
for epoch in range(int(args.num_train_epochs)):
|
for epoch in trange(int(args.num_train_epochs), desc="Epoch"):
|
||||||
#for input_ids, input_mask, segment_ids, label_ids in train_dataloader:
|
for input_ids, input_mask, segment_ids, start_positions, end_positions in tqdm(train_dataloader,
|
||||||
for input_ids, input_mask, segment_ids, start_positions, end_positions in train_dataloader:
|
desc="Iteration"):
|
||||||
input_ids = input_ids.to(device)
|
input_ids = input_ids.to(device)
|
||||||
input_mask = input_mask.float().to(device)
|
input_mask = input_mask.float().to(device)
|
||||||
segment_ids = segment_ids.to(device)
|
segment_ids = segment_ids.to(device)
|
||||||
|
|||||||
Reference in New Issue
Block a user