[lightning_base] fix s2s logging, only make train_loader once (#6404)
This commit is contained in:
@@ -10,14 +10,7 @@ from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from lightning_base import generic_train
|
||||
from transformers import (
|
||||
AdamW,
|
||||
BartConfig,
|
||||
BartForConditionalGeneration,
|
||||
MBartTokenizer,
|
||||
T5Config,
|
||||
T5ForConditionalGeneration,
|
||||
)
|
||||
from transformers import BartConfig, BartForConditionalGeneration, MBartTokenizer, T5Config, T5ForConditionalGeneration
|
||||
|
||||
|
||||
try:
|
||||
@@ -158,24 +151,6 @@ class BartSummarizationDistiller(SummarizationModule):
|
||||
)
|
||||
return loss_ce, s_logits_slct, t_logits_slct
|
||||
|
||||
def configure_optimizers(self):
    """Build the AdamW optimizer over ``self.model``'s parameters.

    Parameters are split into two groups by name: any parameter whose name
    contains ``"bias"`` or ``"LayerNorm.weight"`` gets a weight decay of
    ``0.0``; every other parameter uses ``self.hparams.weight_decay``.

    Returns:
        A one-element list containing the optimizer. The same optimizer is
        also stored on ``self.opt``.
    """
    exempt_markers = ("bias", "LayerNorm.weight")

    # Partition in one pass; each list keeps the order in which
    # named_parameters() yields its entries.
    decayed_params = []
    exempt_params = []
    for param_name, param in self.model.named_parameters():
        if any(marker in param_name for marker in exempt_markers):
            exempt_params.append(param)
        else:
            decayed_params.append(param)

    param_groups = [
        {"params": decayed_params, "weight_decay": self.hparams.weight_decay},
        {"params": exempt_params, "weight_decay": 0.0},
    ]
    self.opt = AdamW(
        param_groups,
        lr=self.hparams.learning_rate,
        eps=self.hparams.adam_epsilon,
    )
    return [self.opt]
|
||||
|
||||
@staticmethod
|
||||
def add_model_specific_args(parser, root_dir):
|
||||
SummarizationModule.add_model_specific_args(parser, root_dir)
|
||||
|
||||
Reference in New Issue
Block a user