only on main process
This commit is contained in:
@@ -917,6 +917,7 @@ def main():
|
||||
model = torch.nn.DataParallel(model)
|
||||
|
||||
if args.do_train:
|
||||
if args.local_rank in [-1, 0]:
|
||||
writer = SummaryWriter()
|
||||
# Prepare data loader
|
||||
train_examples = read_squad_examples(
|
||||
@@ -1016,6 +1017,7 @@ def main():
|
||||
else:
|
||||
loss.backward()
|
||||
if (step + 1) % args.gradient_accumulation_steps == 0:
|
||||
if args.local_rank in [-1, 0]:
|
||||
writer.add_scalar('lr', optimizer.get_lr()[0], global_step)
|
||||
writer.add_scalar('loss', loss.item(), global_step)
|
||||
if args.fp16:
|
||||
|
||||
Reference in New Issue
Block a user