[lightning_base] fix s2s logging, only make train_loader once (#6404)
This commit is contained in:
@@ -150,15 +150,20 @@ class BaseTransformer(pl.LightningModule):
|
||||
def test_epoch_end(self, outputs):
|
||||
return self.validation_end(outputs)
|
||||
|
||||
def setup(self, step):
|
||||
train_batch_size = self.hparams.train_batch_size
|
||||
dataloader = self.get_dataloader("train", train_batch_size)
|
||||
self.train_loader = dataloader
|
||||
self.total_steps = (
|
||||
(len(dataloader.dataset) // (train_batch_size * max(1, self.hparams.gpus)))
|
||||
// self.hparams.accumulate_grad_batches
|
||||
* float(self.hparams.max_epochs)
|
||||
)
|
||||
@property
|
||||
def total_steps(self) -> int:
|
||||
"""The number of total training steps that will be run. Used for lr scheduler purposes."""
|
||||
num_devices = max(1, self.hparams.gpus) # TODO: consider num_tpu_cores
|
||||
effective_batch_size = self.hparams.train_batch_size * self.hparams.accumulate_grad_batches * num_devices
|
||||
dataset_size = len(self.train_loader.dataset)
|
||||
return (dataset_size / effective_batch_size) * self.hparams.max_epochs
|
||||
|
||||
def setup(self, mode):
|
||||
if mode == "fit":
|
||||
self.train_loader = self.get_dataloader("train", self.hparams.train_batch_size, shuffle=True)
|
||||
|
||||
def get_dataloader(self, type_path, batch_size, shuffle=False):
|
||||
raise NotImplementedError("You must implement this for your task")
|
||||
|
||||
def train_dataloader(self):
|
||||
return self.train_loader
|
||||
@@ -304,6 +309,13 @@ def add_generic_args(parser, root_dir) -> None:
|
||||
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
|
||||
parser.add_argument(
|
||||
"--data_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The input data dir. Should contain the training files for the CoNLL-2003 NER task.",
|
||||
)
|
||||
|
||||
|
||||
def generic_train(
|
||||
|
||||
Reference in New Issue
Block a user