From 326944d627ad166f0e6a6921b7168a2caf31dd1e Mon Sep 17 00:00:00 2001 From: thomwolf Date: Tue, 18 Jun 2019 14:02:42 +0200 Subject: [PATCH] add tensorboard to run_squad --- examples/run_squad.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/examples/run_squad.py b/examples/run_squad.py index 313cb453af..775e93e4db 100644 --- a/examples/run_squad.py +++ b/examples/run_squad.py @@ -34,6 +34,8 @@ from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler, from torch.utils.data.distributed import DistributedSampler from tqdm import tqdm, trange +from tensorboardX import SummaryWriter + from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE, WEIGHTS_NAME, CONFIG_NAME from pytorch_pretrained_bert.modeling import BertForQuestionAnswering, BertConfig from pytorch_pretrained_bert.optimization import BertAdam, WarmupLinearSchedule @@ -915,9 +917,8 @@ def main(): model = torch.nn.DataParallel(model) if args.do_train: - + writer = SummaryWriter() # Prepare data loader - train_examples = read_squad_examples( input_file=args.train_file, is_training=True, version_2_with_negative=args.version_2_with_negative) cached_train_features_file = args.train_file+'_{0}_{1}_{2}_{3}'.format( @@ -999,7 +1000,7 @@ def main(): logger.info(" Num steps = %d", num_train_optimization_steps) model.train() - for _ in trange(int(args.num_train_epochs), desc="Epoch"): + for epoch in trange(int(args.num_train_epochs), desc="Epoch"): for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])): if n_gpu == 1: batch = tuple(t.to(device) for t in batch) # multi-gpu does scattering it-self @@ -1015,6 +1016,8 @@ def main(): else: loss.backward() if (step + 1) % args.gradient_accumulation_steps == 0: + writer.add_scalar('lr', optimizer.get_lr()[0], global_step) + writer.add_scalar('loss', loss.item(), global_step) if args.fp16: # modify learning rate with special warm up BERT uses # if args.fp16 is False, BertAdam is used and handles this automatically