update run_squad with tqdm

This commit is contained in:
thomwolf
2018-11-03 17:52:44 +01:00
parent cb76c1ddd3
commit f514cbbf30

View File

@@ -18,19 +18,20 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import six
import argparse
import collections import collections
import logging import logging
import json import json
import math import math
import os import os
import tokenization_pytorch from tqdm import tqdm, trange
import six
import argparse
import torch import torch
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler from torch.utils.data.distributed import DistributedSampler
import tokenization_pytorch
from modeling_pytorch import BertConfig, BertForQuestionAnswering from modeling_pytorch import BertConfig, BertForQuestionAnswering
from optimization_pytorch import BERTAdam from optimization_pytorch import BERTAdam
@@ -841,9 +842,9 @@ def main():
logger.info("HHHHH Starting Traing") logger.info("HHHHH Starting Traing")
model.train() model.train()
for epoch in range(int(args.num_train_epochs)): for epoch in trange(int(args.num_train_epochs), desc="Epoch"):
#for input_ids, input_mask, segment_ids, label_ids in train_dataloader: for input_ids, input_mask, segment_ids, start_positions, end_positions in tqdm(train_dataloader,
for input_ids, input_mask, segment_ids, start_positions, end_positions in train_dataloader: desc="Iteration"):
input_ids = input_ids.to(device) input_ids = input_ids.to(device)
input_mask = input_mask.float().to(device) input_mask = input_mask.float().to(device)
segment_ids = segment_ids.to(device) segment_ids = segment_ids.to(device)