From 7816f7921fd5a21fdc74ca0f29589c74bceed0e2 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Mon, 15 Apr 2019 15:27:10 +0200 Subject: [PATCH] clean up distributed training logging in run_squad example --- examples/run_squad.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/run_squad.py b/examples/run_squad.py index 00ee368b14..bad46203bc 100644 --- a/examples/run_squad.py +++ b/examples/run_squad.py @@ -985,7 +985,7 @@ def main(): model.train() for _ in trange(int(args.num_train_epochs), desc="Epoch"): - for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")): + for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])): if n_gpu == 1: batch = tuple(t.to(device) for t in batch) # multi-gpu does scattering it-self input_ids, input_mask, segment_ids, start_positions, end_positions = batch @@ -1058,7 +1058,7 @@ def main(): model.eval() all_results = [] logger.info("Start evaluating") - for input_ids, input_mask, segment_ids, example_indices in tqdm(eval_dataloader, desc="Evaluating"): + for input_ids, input_mask, segment_ids, example_indices in tqdm(eval_dataloader, desc="Evaluating", disable=args.local_rank not in [-1, 0]): if len(all_results) % 1000 == 0: logger.info("Processing example: %d" % (len(all_results))) input_ids = input_ids.to(device)