From 59cefd4f985b7221846189690ead3300ff864b3d Mon Sep 17 00:00:00 2001 From: thomwolf Date: Wed, 26 Jun 2019 11:28:27 +0200 Subject: [PATCH] fix #726 - get_lr in examples --- examples/run_bert_squad.py | 3 ++- examples/run_xlnet_classifier.py | 3 ++- examples/run_xlnet_squad.py | 3 ++- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/examples/run_bert_squad.py b/examples/run_bert_squad.py index b35a9175ec..9aaa711c2b 100644 --- a/examples/run_bert_squad.py +++ b/examples/run_bert_squad.py @@ -313,7 +313,8 @@ def main(): optimizer.zero_grad() global_step += 1 if args.local_rank in [-1, 0]: - tb_writer.add_scalar('lr', optimizer.get_lr()[0], global_step) + if not args.fp16: + tb_writer.add_scalar('lr', optimizer.get_lr()[0], global_step) tb_writer.add_scalar('loss', loss.item(), global_step) if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0): diff --git a/examples/run_xlnet_classifier.py b/examples/run_xlnet_classifier.py index 0278b40cdd..2309815981 100644 --- a/examples/run_xlnet_classifier.py +++ b/examples/run_xlnet_classifier.py @@ -319,7 +319,8 @@ def main(): optimizer.zero_grad() global_step += 1 if args.local_rank in [-1, 0] and (args.log_every <= 0 or (step + 1) % args.log_every == 0): - tb_writer.add_scalar('lr', optimizer.get_lr()[0], global_step) + if not args.fp16: + tb_writer.add_scalar('lr', optimizer.get_lr()[0], global_step) tb_writer.add_scalar('loss', loss.item(), global_step) ### Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained() diff --git a/examples/run_xlnet_squad.py b/examples/run_xlnet_squad.py index a72d648ff7..927668c57a 100644 --- a/examples/run_xlnet_squad.py +++ b/examples/run_xlnet_squad.py @@ -313,7 +313,8 @@ def main(): optimizer.zero_grad() global_step += 1 if args.local_rank in [-1, 0]: - tb_writer.add_scalar('lr', optimizer.get_lr()[0], global_step) + if not args.fp16: + tb_writer.add_scalar('lr', optimizer.get_lr()[0], global_step) tb_writer.add_scalar('loss', loss.item(), global_step) if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):