add tensorboard to run_squad
This commit is contained in:
@@ -34,6 +34,8 @@ from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
|
|||||||
from torch.utils.data.distributed import DistributedSampler
|
from torch.utils.data.distributed import DistributedSampler
|
||||||
from tqdm import tqdm, trange
|
from tqdm import tqdm, trange
|
||||||
|
|
||||||
|
from tensorboardX import SummaryWriter
|
||||||
|
|
||||||
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE, WEIGHTS_NAME, CONFIG_NAME
|
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE, WEIGHTS_NAME, CONFIG_NAME
|
||||||
from pytorch_pretrained_bert.modeling import BertForQuestionAnswering, BertConfig
|
from pytorch_pretrained_bert.modeling import BertForQuestionAnswering, BertConfig
|
||||||
from pytorch_pretrained_bert.optimization import BertAdam, WarmupLinearSchedule
|
from pytorch_pretrained_bert.optimization import BertAdam, WarmupLinearSchedule
|
||||||
@@ -915,9 +917,8 @@ def main():
|
|||||||
model = torch.nn.DataParallel(model)
|
model = torch.nn.DataParallel(model)
|
||||||
|
|
||||||
if args.do_train:
|
if args.do_train:
|
||||||
|
writer = SummaryWriter()
|
||||||
# Prepare data loader
|
# Prepare data loader
|
||||||
|
|
||||||
train_examples = read_squad_examples(
|
train_examples = read_squad_examples(
|
||||||
input_file=args.train_file, is_training=True, version_2_with_negative=args.version_2_with_negative)
|
input_file=args.train_file, is_training=True, version_2_with_negative=args.version_2_with_negative)
|
||||||
cached_train_features_file = args.train_file+'_{0}_{1}_{2}_{3}'.format(
|
cached_train_features_file = args.train_file+'_{0}_{1}_{2}_{3}'.format(
|
||||||
@@ -999,7 +1000,7 @@ def main():
|
|||||||
logger.info(" Num steps = %d", num_train_optimization_steps)
|
logger.info(" Num steps = %d", num_train_optimization_steps)
|
||||||
|
|
||||||
model.train()
|
model.train()
|
||||||
for _ in trange(int(args.num_train_epochs), desc="Epoch"):
|
for epoch in trange(int(args.num_train_epochs), desc="Epoch"):
|
||||||
for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])):
|
for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])):
|
||||||
if n_gpu == 1:
|
if n_gpu == 1:
|
||||||
batch = tuple(t.to(device) for t in batch) # multi-gpu does scattering it-self
|
batch = tuple(t.to(device) for t in batch) # multi-gpu does scattering it-self
|
||||||
@@ -1015,6 +1016,8 @@ def main():
|
|||||||
else:
|
else:
|
||||||
loss.backward()
|
loss.backward()
|
||||||
if (step + 1) % args.gradient_accumulation_steps == 0:
|
if (step + 1) % args.gradient_accumulation_steps == 0:
|
||||||
|
writer.add_scalar('lr', optimizer.get_lr()[0], global_step)
|
||||||
|
writer.add_scalar('loss', loss.item(), global_step)
|
||||||
if args.fp16:
|
if args.fp16:
|
||||||
# modify learning rate with special warm up BERT uses
|
# modify learning rate with special warm up BERT uses
|
||||||
# if args.fp16 is False, BertAdam is used and handles this automatically
|
# if args.fp16 is False, BertAdam is used and handles this automatically
|
||||||
|
|||||||
Reference in New Issue
Block a user