From 16a1f338c4b0f90e3476a777b06c438e21b87b37 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Tue, 18 Jun 2019 17:06:31 +0200 Subject: [PATCH] fixing --- examples/run_classifier.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/run_classifier.py b/examples/run_classifier.py index 63cc2b4b9c..7c00e4833d 100644 --- a/examples/run_classifier.py +++ b/examples/run_classifier.py @@ -222,6 +222,10 @@ def main(): elif n_gpu > 1: model = torch.nn.DataParallel(model) + global_step = 0 + nb_tr_steps = 0 + tr_loss = 0 + if args.do_train: if args.local_rank in [-1, 0]: tb_writer = SummaryWriter() @@ -293,10 +297,6 @@ def main(): warmup=args.warmup_proportion, t_total=num_train_optimization_steps) - global_step = 0 - nb_tr_steps = 0 - tr_loss = 0 - logger.info("***** Running training *****") logger.info(" Num examples = %d", len(train_examples)) logger.info(" Batch size = %d", args.train_batch_size)