From 0203d6517fb510ff05cddb65168c07655c3c8168 Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Tue, 11 Aug 2020 13:27:11 -0700 Subject: [PATCH] [pl] restore lr logging behavior for glue, ner examples (#6314) --- examples/lightning_base.py | 3 ++- examples/text-classification/run_pl_glue.py | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/examples/lightning_base.py b/examples/lightning_base.py index e2532a28ea..11a4ded828 100644 --- a/examples/lightning_base.py +++ b/examples/lightning_base.py @@ -245,7 +245,8 @@ class BaseTransformer(pl.LightningModule): class LoggingCallback(pl.Callback): def on_batch_end(self, trainer, pl_module): - lrs = {f"lr_group_{i}": param["lr"] for i, param in enumerate(pl_module.trainer.optimizers[0].param_groups)} + lr_scheduler = trainer.lr_schedulers[0]["scheduler"] + lrs = {f"lr_group_{i}": lr for i, lr in enumerate(lr_scheduler.get_lr())} pl_module.logger.log_metrics(lrs) def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule): diff --git a/examples/text-classification/run_pl_glue.py b/examples/text-classification/run_pl_glue.py index 459c7324a0..36c1825972 100644 --- a/examples/text-classification/run_pl_glue.py +++ b/examples/text-classification/run_pl_glue.py @@ -44,8 +44,8 @@ class GLUETransformer(BaseTransformer): outputs = self(**inputs) loss = outputs[0] - # tensorboard_logs = {"loss": loss, "rate": self.lr_scheduler.get_last_lr()[-1]} - tensorboard_logs = {"loss": loss} + lr_scheduler = self.trainer.lr_schedulers[0]["scheduler"] + tensorboard_logs = {"loss": loss, "rate": lr_scheduler.get_last_lr()[-1]} return {"loss": loss, "log": tensorboard_logs} def prepare_data(self):