[examples] bump pl=0.9.0 (#7053)

This commit is contained in:
Sam Shleifer
2020-10-11 16:39:38 -04:00
committed by GitHub
parent ba4bbd92bc
commit 827c519494
7 changed files with 27 additions and 42 deletions

View File

@@ -119,7 +119,7 @@ class BaseTransformer(pl.LightningModule):
def get_lr_scheduler(self):
get_schedule_func = arg_to_scheduler[self.hparams.lr_scheduler]
scheduler = get_schedule_func(
self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=self.total_steps
self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=self.total_steps()
)
scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1}
return scheduler
@@ -159,19 +159,20 @@ class BaseTransformer(pl.LightningModule):
def test_epoch_end(self, outputs):
return self.validation_end(outputs)
@property
def total_steps(self) -> int:
"""The number of total training steps that will be run. Used for lr scheduler purposes."""
num_devices = max(1, self.hparams.gpus) # TODO: consider num_tpu_cores
effective_batch_size = self.hparams.train_batch_size * self.hparams.accumulate_grad_batches * num_devices
dataset_size = len(self.train_loader.dataset)
return (dataset_size / effective_batch_size) * self.hparams.max_epochs
return (self.dataset_size / effective_batch_size) * self.hparams.max_epochs
def setup(self, mode):
if mode == "fit":
if mode == "test":
self.dataset_size = len(self.test_dataloader().dataset)
else:
self.train_loader = self.get_dataloader("train", self.hparams.train_batch_size, shuffle=True)
self.dataset_size = len(self.train_loader.dataset)
def get_dataloader(self, type_path, batch_size, shuffle=False):
def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False):
raise NotImplementedError("You must implement this for your task")
def train_dataloader(self):