s2s: fix LR logging, remove some dead code. (#6205)

This commit is contained in:
Sam Shleifer
2020-08-03 10:36:26 -04:00
committed by GitHub
parent 06f1692b02
commit b6b2f2270f
3 changed files with 5 additions and 7 deletions

View File

@@ -19,6 +19,10 @@ logger = logging.getLogger(__name__)
class Seq2SeqLoggingCallback(pl.Callback):
def on_batch_end(self, trainer, pl_module):
lrs = {f"lr_group_{i}": param["lr"] for i, param in enumerate(pl_module.trainer.optimizers[0].param_groups)}
pl_module.logger.log_metrics(lrs)
@rank_zero_only
def _write_logs(
self, trainer: pl.Trainer, pl_module: pl.LightningModule, type_path: str, save_generations=True