only on main process

This commit is contained in:
thomwolf
2019-06-18 14:03:46 +02:00
parent 326944d627
commit 335f57baf8

View File

@@ -917,6 +917,7 @@ def main():
model = torch.nn.DataParallel(model)
if args.do_train:
if args.local_rank in [-1, 0]:
writer = SummaryWriter()
# Prepare data loader
train_examples = read_squad_examples(
@@ -1016,6 +1017,7 @@ def main():
else:
loss.backward()
if (step + 1) % args.gradient_accumulation_steps == 0:
if args.local_rank in [-1, 0]:
writer.add_scalar('lr', optimizer.get_lr()[0], global_step)
writer.add_scalar('loss', loss.item(), global_step)
if args.fp16: