From f514cbbf301fea40e15590059a27509e9f0e22a2 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Sat, 3 Nov 2018 17:52:44 +0100 Subject: [PATCH] update run_squad with tqdm --- run_squad_pytorch.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/run_squad_pytorch.py b/run_squad_pytorch.py index d5b771b91a..9b53bcbc5e 100644 --- a/run_squad_pytorch.py +++ b/run_squad_pytorch.py @@ -18,19 +18,20 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import six +import argparse import collections import logging import json import math import os -import tokenization_pytorch -import six -import argparse +from tqdm import tqdm, trange import torch from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler from torch.utils.data.distributed import DistributedSampler +import tokenization_pytorch from modeling_pytorch import BertConfig, BertForQuestionAnswering from optimization_pytorch import BERTAdam @@ -841,9 +842,9 @@ def main(): logger.info("HHHHH Starting Traing") model.train() - for epoch in range(int(args.num_train_epochs)): - #for input_ids, input_mask, segment_ids, label_ids in train_dataloader: - for input_ids, input_mask, segment_ids, start_positions, end_positions in train_dataloader: + for epoch in trange(int(args.num_train_epochs), desc="Epoch"): + for input_ids, input_mask, segment_ids, start_positions, end_positions in tqdm(train_dataloader, + desc="Iteration"): input_ids = input_ids.to(device) input_mask = input_mask.float().to(device) segment_ids = segment_ids.to(device)