s2s: fix LR logging, remove some dead code. (#6205)
This commit is contained in:
@@ -58,7 +58,6 @@ class BaseTransformer(pl.LightningModule):
|
||||
|
||||
self.hparams = hparams
|
||||
self.step_count = 0
|
||||
self.tfmr_ckpts = {}
|
||||
self.output_dir = Path(self.hparams.output_dir)
|
||||
cache_dir = self.hparams.cache_dir if self.hparams.cache_dir else None
|
||||
if config is None:
|
||||
@@ -99,7 +98,7 @@ class BaseTransformer(pl.LightningModule):
|
||||
self.model = self.model_type.from_pretrained(*args, **kwargs)
|
||||
|
||||
def configure_optimizers(self):
|
||||
"Prepare optimizer and schedule (linear warmup and decay)"
|
||||
"""Prepare optimizer and schedule (linear warmup and decay)"""
|
||||
model = self.model
|
||||
no_decay = ["bias", "LayerNorm.weight"]
|
||||
optimizer_grouped_parameters = [
|
||||
@@ -159,11 +158,9 @@ class BaseTransformer(pl.LightningModule):
|
||||
@pl.utilities.rank_zero_only
|
||||
def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
|
||||
save_path = self.output_dir.joinpath("best_tfmr")
|
||||
save_path.mkdir(exist_ok=True)
|
||||
self.model.config.save_step = self.step_count
|
||||
self.model.save_pretrained(save_path)
|
||||
self.tokenizer.save_pretrained(save_path)
|
||||
self.tfmr_ckpts[self.step_count] = save_path
|
||||
|
||||
@staticmethod
|
||||
def add_model_specific_args(parser, root_dir):
|
||||
@@ -274,7 +271,6 @@ def add_generic_args(parser, root_dir) -> None:
|
||||
default=1,
|
||||
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
||||
)
|
||||
|
||||
parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user