From b6b2f2270fe6c32852fc1b887afe354b7b79d18c Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Mon, 3 Aug 2020 10:36:26 -0400 Subject: [PATCH] s2s: fix LR logging, remove some dead code. (#6205) --- examples/lightning_base.py | 6 +----- examples/seq2seq/callbacks.py | 4 ++++ examples/seq2seq/train_mbart_cc25_enro.sh | 2 -- 3 files changed, 5 insertions(+), 7 deletions(-) diff --git a/examples/lightning_base.py b/examples/lightning_base.py index 754538e792..ae03e29561 100644 --- a/examples/lightning_base.py +++ b/examples/lightning_base.py @@ -58,7 +58,6 @@ class BaseTransformer(pl.LightningModule): self.hparams = hparams self.step_count = 0 - self.tfmr_ckpts = {} self.output_dir = Path(self.hparams.output_dir) cache_dir = self.hparams.cache_dir if self.hparams.cache_dir else None if config is None: @@ -99,7 +98,7 @@ class BaseTransformer(pl.LightningModule): self.model = self.model_type.from_pretrained(*args, **kwargs) def configure_optimizers(self): - "Prepare optimizer and schedule (linear warmup and decay)" + """Prepare optimizer and schedule (linear warmup and decay)""" model = self.model no_decay = ["bias", "LayerNorm.weight"] optimizer_grouped_parameters = [ @@ -159,11 +158,9 @@ class BaseTransformer(pl.LightningModule): @pl.utilities.rank_zero_only def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: save_path = self.output_dir.joinpath("best_tfmr") - save_path.mkdir(exist_ok=True) self.model.config.save_step = self.step_count self.model.save_pretrained(save_path) self.tokenizer.save_pretrained(save_path) - self.tfmr_ckpts[self.step_count] = save_path @staticmethod def add_model_specific_args(parser, root_dir): @@ -274,7 +271,6 @@ def add_generic_args(parser, root_dir) -> None: default=1, help="Number of updates steps to accumulate before performing a backward/update pass.", ) - parser.add_argument("--seed", type=int, default=42, help="random seed for initialization") diff --git a/examples/seq2seq/callbacks.py b/examples/seq2seq/callbacks.py index 1de3aa5d46..68e06a5f48 100644 --- a/examples/seq2seq/callbacks.py +++ b/examples/seq2seq/callbacks.py @@ -19,6 +19,10 @@ logger = logging.getLogger(__name__) class Seq2SeqLoggingCallback(pl.Callback): + def on_batch_end(self, trainer, pl_module): + lrs = {f"lr_group_{i}": param["lr"] for i, param in enumerate(pl_module.trainer.optimizers[0].param_groups)} + pl_module.logger.log_metrics(lrs) + @rank_zero_only def _write_logs( self, trainer: pl.Trainer, pl_module: pl.LightningModule, type_path: str, save_generations=True diff --git a/examples/seq2seq/train_mbart_cc25_enro.sh b/examples/seq2seq/train_mbart_cc25_enro.sh index 4ec18de369..b8122aee3f 100755 --- a/examples/seq2seq/train_mbart_cc25_enro.sh +++ b/examples/seq2seq/train_mbart_cc25_enro.sh @@ -5,7 +5,6 @@ python finetune.py \ --learning_rate=3e-5 \ --fp16 \ --do_train \ - --do_predict \ --val_check_interval=0.25 \ --adam_eps 1e-06 \ --num_train_epochs 6 --src_lang en_XX --tgt_lang ro_RO \ @@ -15,6 +14,5 @@ python finetune.py \ --task translation \ --warmup_steps 500 \ --freeze_embeds \ - --early_stopping_patience 4 \ --model_name_or_path=facebook/mbart-large-cc25 \ $@