s2s: fix LR logging, remove some dead code. (#6205)
This commit is contained in:
@@ -19,6 +19,10 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Seq2SeqLoggingCallback(pl.Callback):
|
||||
def on_batch_end(self, trainer, pl_module):
|
||||
lrs = {f"lr_group_{i}": param["lr"] for i, param in enumerate(pl_module.trainer.optimizers[0].param_groups)}
|
||||
pl_module.logger.log_metrics(lrs)
|
||||
|
||||
@rank_zero_only
|
||||
def _write_logs(
|
||||
self, trainer: pl.Trainer, pl_module: pl.LightningModule, type_path: str, save_generations=True
|
||||
|
||||
Reference in New Issue
Block a user