fix #726 - get_lr in examples

This commit is contained in:
thomwolf
2019-06-26 11:28:27 +02:00
parent ddc2cc61a6
commit 59cefd4f98
3 changed files with 6 additions and 3 deletions

View File

@@ -313,6 +313,7 @@ def main():
optimizer.zero_grad() optimizer.zero_grad()
global_step += 1 global_step += 1
if args.local_rank in [-1, 0]: if args.local_rank in [-1, 0]:
if not args.fp16:
tb_writer.add_scalar('lr', optimizer.get_lr()[0], global_step) tb_writer.add_scalar('lr', optimizer.get_lr()[0], global_step)
tb_writer.add_scalar('loss', loss.item(), global_step) tb_writer.add_scalar('loss', loss.item(), global_step)

View File

@@ -319,6 +319,7 @@ def main():
optimizer.zero_grad() optimizer.zero_grad()
global_step += 1 global_step += 1
if args.local_rank in [-1, 0] and (args.log_every <= 0 or (step + 1) % args.log_every == 0): if args.local_rank in [-1, 0] and (args.log_every <= 0 or (step + 1) % args.log_every == 0):
if not args.fp16:
tb_writer.add_scalar('lr', optimizer.get_lr()[0], global_step) tb_writer.add_scalar('lr', optimizer.get_lr()[0], global_step)
tb_writer.add_scalar('loss', loss.item(), global_step) tb_writer.add_scalar('loss', loss.item(), global_step)

View File

@@ -313,6 +313,7 @@ def main():
optimizer.zero_grad() optimizer.zero_grad()
global_step += 1 global_step += 1
if args.local_rank in [-1, 0]: if args.local_rank in [-1, 0]:
if not args.fp16:
tb_writer.add_scalar('lr', optimizer.get_lr()[0], global_step) tb_writer.add_scalar('lr', optimizer.get_lr()[0], global_step)
tb_writer.add_scalar('loss', loss.item(), global_step) tb_writer.add_scalar('loss', loss.item(), global_step)