only on main process
This commit is contained in:
@@ -917,6 +917,7 @@ def main():
|
|||||||
model = torch.nn.DataParallel(model)
|
model = torch.nn.DataParallel(model)
|
||||||
|
|
||||||
if args.do_train:
|
if args.do_train:
|
||||||
|
if args.local_rank in [-1, 0]:
|
||||||
writer = SummaryWriter()
|
writer = SummaryWriter()
|
||||||
# Prepare data loader
|
# Prepare data loader
|
||||||
train_examples = read_squad_examples(
|
train_examples = read_squad_examples(
|
||||||
@@ -1016,6 +1017,7 @@ def main():
|
|||||||
else:
|
else:
|
||||||
loss.backward()
|
loss.backward()
|
||||||
if (step + 1) % args.gradient_accumulation_steps == 0:
|
if (step + 1) % args.gradient_accumulation_steps == 0:
|
||||||
|
if args.local_rank in [-1, 0]:
|
||||||
writer.add_scalar('lr', optimizer.get_lr()[0], global_step)
|
writer.add_scalar('lr', optimizer.get_lr()[0], global_step)
|
||||||
writer.add_scalar('loss', loss.item(), global_step)
|
writer.add_scalar('loss', loss.item(), global_step)
|
||||||
if args.fp16:
|
if args.fp16:
|
||||||
|
|||||||
Reference in New Issue
Block a user