s2s: fix LR logging, remove some dead code. (#6205)
This commit is contained in:
@@ -58,7 +58,6 @@ class BaseTransformer(pl.LightningModule):
|
|||||||
|
|
||||||
self.hparams = hparams
|
self.hparams = hparams
|
||||||
self.step_count = 0
|
self.step_count = 0
|
||||||
self.tfmr_ckpts = {}
|
|
||||||
self.output_dir = Path(self.hparams.output_dir)
|
self.output_dir = Path(self.hparams.output_dir)
|
||||||
cache_dir = self.hparams.cache_dir if self.hparams.cache_dir else None
|
cache_dir = self.hparams.cache_dir if self.hparams.cache_dir else None
|
||||||
if config is None:
|
if config is None:
|
||||||
@@ -99,7 +98,7 @@ class BaseTransformer(pl.LightningModule):
|
|||||||
self.model = self.model_type.from_pretrained(*args, **kwargs)
|
self.model = self.model_type.from_pretrained(*args, **kwargs)
|
||||||
|
|
||||||
def configure_optimizers(self):
|
def configure_optimizers(self):
|
||||||
"Prepare optimizer and schedule (linear warmup and decay)"
|
"""Prepare optimizer and schedule (linear warmup and decay)"""
|
||||||
model = self.model
|
model = self.model
|
||||||
no_decay = ["bias", "LayerNorm.weight"]
|
no_decay = ["bias", "LayerNorm.weight"]
|
||||||
optimizer_grouped_parameters = [
|
optimizer_grouped_parameters = [
|
||||||
@@ -159,11 +158,9 @@ class BaseTransformer(pl.LightningModule):
|
|||||||
@pl.utilities.rank_zero_only
|
@pl.utilities.rank_zero_only
|
||||||
def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
|
def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
|
||||||
save_path = self.output_dir.joinpath("best_tfmr")
|
save_path = self.output_dir.joinpath("best_tfmr")
|
||||||
save_path.mkdir(exist_ok=True)
|
|
||||||
self.model.config.save_step = self.step_count
|
self.model.config.save_step = self.step_count
|
||||||
self.model.save_pretrained(save_path)
|
self.model.save_pretrained(save_path)
|
||||||
self.tokenizer.save_pretrained(save_path)
|
self.tokenizer.save_pretrained(save_path)
|
||||||
self.tfmr_ckpts[self.step_count] = save_path
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def add_model_specific_args(parser, root_dir):
|
def add_model_specific_args(parser, root_dir):
|
||||||
@@ -274,7 +271,6 @@ def add_generic_args(parser, root_dir) -> None:
|
|||||||
default=1,
|
default=1,
|
||||||
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
|
parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -19,6 +19,10 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class Seq2SeqLoggingCallback(pl.Callback):
|
class Seq2SeqLoggingCallback(pl.Callback):
|
||||||
|
def on_batch_end(self, trainer, pl_module):
|
||||||
|
lrs = {f"lr_group_{i}": param["lr"] for i, param in enumerate(pl_module.trainer.optimizers[0].param_groups)}
|
||||||
|
pl_module.logger.log_metrics(lrs)
|
||||||
|
|
||||||
@rank_zero_only
|
@rank_zero_only
|
||||||
def _write_logs(
|
def _write_logs(
|
||||||
self, trainer: pl.Trainer, pl_module: pl.LightningModule, type_path: str, save_generations=True
|
self, trainer: pl.Trainer, pl_module: pl.LightningModule, type_path: str, save_generations=True
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ python finetune.py \
|
|||||||
--learning_rate=3e-5 \
|
--learning_rate=3e-5 \
|
||||||
--fp16 \
|
--fp16 \
|
||||||
--do_train \
|
--do_train \
|
||||||
--do_predict \
|
|
||||||
--val_check_interval=0.25 \
|
--val_check_interval=0.25 \
|
||||||
--adam_eps 1e-06 \
|
--adam_eps 1e-06 \
|
||||||
--num_train_epochs 6 --src_lang en_XX --tgt_lang ro_RO \
|
--num_train_epochs 6 --src_lang en_XX --tgt_lang ro_RO \
|
||||||
@@ -15,6 +14,5 @@ python finetune.py \
|
|||||||
--task translation \
|
--task translation \
|
||||||
--warmup_steps 500 \
|
--warmup_steps 500 \
|
||||||
--freeze_embeds \
|
--freeze_embeds \
|
||||||
--early_stopping_patience 4 \
|
|
||||||
--model_name_or_path=facebook/mbart-large-cc25 \
|
--model_name_or_path=facebook/mbart-large-cc25 \
|
||||||
$@
|
$@
|
||||||
|
|||||||
Reference in New Issue
Block a user