[lightning_base] fix s2s logging, only make train_loader once (#6404)

This commit is contained in:
Sam Shleifer
2020-08-16 22:49:41 -04:00
committed by GitHub
parent 72add6c98f
commit 84c265ffcc
6 changed files with 47 additions and 72 deletions

View File

@@ -150,15 +150,20 @@ class BaseTransformer(pl.LightningModule):
def test_epoch_end(self, outputs):
return self.validation_end(outputs)
def setup(self, step):
train_batch_size = self.hparams.train_batch_size
dataloader = self.get_dataloader("train", train_batch_size)
self.train_loader = dataloader
self.total_steps = (
(len(dataloader.dataset) // (train_batch_size * max(1, self.hparams.gpus)))
// self.hparams.accumulate_grad_batches
* float(self.hparams.max_epochs)
)
@property
def total_steps(self) -> float:
    """The number of total training steps that will be run. Used for lr scheduler purposes.

    Computed as ``dataset_size / effective_batch_size * max_epochs``, where the
    effective batch size is ``train_batch_size * accumulate_grad_batches *
    num_devices``. True division is used, so the result is a ``float`` and may
    be fractional when the dataset size is not a multiple of the effective
    batch size.

    Reads ``self.train_loader``, so the loader must have been assigned before
    this property is accessed.
    """
    num_devices = max(1, self.hparams.gpus)  # TODO: consider num_tpu_cores
    effective_batch_size = self.hparams.train_batch_size * self.hparams.accumulate_grad_batches * num_devices
    dataset_size = len(self.train_loader.dataset)
    return (dataset_size / effective_batch_size) * self.hparams.max_epochs
def setup(self, mode):
if mode == "fit":
self.train_loader = self.get_dataloader("train", self.hparams.train_batch_size, shuffle=True)
def get_dataloader(self, type_path, batch_size, shuffle=False):
    """Return a dataloader for the given split; must be overridden.

    Args:
        type_path: Name of the data split (e.g. ``"train"``).
        batch_size: Batch size for the loader.
        shuffle: Whether the loader should shuffle its data.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    message = "You must implement this for your task"
    raise NotImplementedError(message)
def train_dataloader(self):
    """Return the training dataloader stored on ``self.train_loader``."""
    loader = self.train_loader
    return loader
@@ -304,6 +309,13 @@ def add_generic_args(parser, root_dir) -> None:
help="Number of updates steps to accumulate before performing a backward/update pass.",
)
parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
parser.add_argument(
"--data_dir",
default=None,
type=str,
required=True,
help="The input data dir. Should contain the training files for the CoNLL-2003 NER task.",
)
def generic_train(