fix #726 - get_lr in examples
This commit is contained in:
@@ -313,6 +313,7 @@ def main():
|
|||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
global_step += 1
|
global_step += 1
|
||||||
if args.local_rank in [-1, 0]:
|
if args.local_rank in [-1, 0]:
|
||||||
|
if not args.fp16:
|
||||||
tb_writer.add_scalar('lr', optimizer.get_lr()[0], global_step)
|
tb_writer.add_scalar('lr', optimizer.get_lr()[0], global_step)
|
||||||
tb_writer.add_scalar('loss', loss.item(), global_step)
|
tb_writer.add_scalar('loss', loss.item(), global_step)
|
||||||
|
|
||||||
|
|||||||
@@ -319,6 +319,7 @@ def main():
|
|||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
global_step += 1
|
global_step += 1
|
||||||
if args.local_rank in [-1, 0] and (args.log_every <= 0 or (step + 1) % args.log_every == 0):
|
if args.local_rank in [-1, 0] and (args.log_every <= 0 or (step + 1) % args.log_every == 0):
|
||||||
|
if not args.fp16:
|
||||||
tb_writer.add_scalar('lr', optimizer.get_lr()[0], global_step)
|
tb_writer.add_scalar('lr', optimizer.get_lr()[0], global_step)
|
||||||
tb_writer.add_scalar('loss', loss.item(), global_step)
|
tb_writer.add_scalar('loss', loss.item(), global_step)
|
||||||
|
|
||||||
|
|||||||
@@ -313,6 +313,7 @@ def main():
|
|||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
global_step += 1
|
global_step += 1
|
||||||
if args.local_rank in [-1, 0]:
|
if args.local_rank in [-1, 0]:
|
||||||
|
if not args.fp16:
|
||||||
tb_writer.add_scalar('lr', optimizer.get_lr()[0], global_step)
|
tb_writer.add_scalar('lr', optimizer.get_lr()[0], global_step)
|
||||||
tb_writer.add_scalar('loss', loss.item(), global_step)
|
tb_writer.add_scalar('loss', loss.item(), global_step)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user