From 0b7a20c651afd014e884ba2c1a486fa971689f19 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Sun, 4 Nov 2018 11:07:34 +0100 Subject: [PATCH] add tqdm, clean up logging --- modeling.py | 2 -- run_squad.py | 4 ++-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/modeling.py b/modeling.py index b08a22966c..db9ca6e0d8 100644 --- a/modeling.py +++ b/modeling.py @@ -435,7 +435,6 @@ class BertForSequenceClassification(nn.Module): def init_weights(m): if isinstance(m, (nn.Linear, nn.Embedding)): - print("Initializing {}".format(m)) # Slight difference here with the TF version which uses truncated_normal # cf https://github.com/pytorch/pytorch/pull/5617 m.weight.data.normal_(config.initializer_range) @@ -481,7 +480,6 @@ class BertForQuestionAnswering(nn.Module): def init_weights(m): if isinstance(m, (nn.Linear, nn.Embedding)): - print("Initializing {}".format(m)) # Slight difference here with the TF version which uses truncated_normal for initialization # cf https://github.com/pytorch/pytorch/pull/5617 m.weight.data.normal_(config.initializer_range) diff --git a/run_squad.py b/run_squad.py index 46c689496c..53334a68d0 100644 --- a/run_squad.py +++ b/run_squad.py @@ -912,9 +912,9 @@ def main(): model.eval() all_results = [] - logger.info("Start evaulating") + logger.info("Start evaluating") #for input_ids, input_mask, segment_ids, label_ids, example_index in eval_dataloader: - for input_ids, input_mask, segment_ids, example_index in eval_dataloader: + for input_ids, input_mask, segment_ids, example_index in tqdm(eval_dataloader, descr="Evaluating"): if len(all_results) % 1000 == 0: logger.info("Processing example: %d" % (len(all_results)))