s2s: fix LR logging, remove some dead code. (#6205)
This commit is contained in:
@@ -19,6 +19,10 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Seq2SeqLoggingCallback(pl.Callback):
|
||||
def on_batch_end(self, trainer, pl_module):
|
||||
lrs = {f"lr_group_{i}": param["lr"] for i, param in enumerate(pl_module.trainer.optimizers[0].param_groups)}
|
||||
pl_module.logger.log_metrics(lrs)
|
||||
|
||||
@rank_zero_only
|
||||
def _write_logs(
|
||||
self, trainer: pl.Trainer, pl_module: pl.LightningModule, type_path: str, save_generations=True
|
||||
|
||||
@@ -5,7 +5,6 @@ python finetune.py \
|
||||
--learning_rate=3e-5 \
|
||||
--fp16 \
|
||||
--do_train \
|
||||
--do_predict \
|
||||
--val_check_interval=0.25 \
|
||||
--adam_eps 1e-06 \
|
||||
--num_train_epochs 6 --src_lang en_XX --tgt_lang ro_RO \
|
||||
@@ -15,6 +14,5 @@ python finetune.py \
|
||||
--task translation \
|
||||
--warmup_steps 500 \
|
||||
--freeze_embeds \
|
||||
--early_stopping_patience 4 \
|
||||
--model_name_or_path=facebook/mbart-large-cc25 \
|
||||
$@
|
||||
|
||||
Reference in New Issue
Block a user