Upgrade PyTorch Lightning to 1.0.2 (#7852)

Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
This commit is contained in:
Sean Naren
2020-10-28 18:59:14 +00:00
committed by GitHub
parent 1b6c8d4811
commit 5e24982e58
8 changed files with 11 additions and 13 deletions

View File

@@ -182,7 +182,6 @@ class SummarizationModule(BaseTransformer):
return self._generative_step(batch)
def validation_epoch_end(self, outputs, prefix="val") -> Dict:
self.step_count += 1
losses = {k: torch.stack([x[k] for x in outputs]).mean() for k in self.loss_names}
loss = losses["loss"]
@@ -252,7 +251,7 @@ class SummarizationModule(BaseTransformer):
def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False) -> DataLoader:
dataset = self.get_dataset(type_path)
if self.hparams.sortish_sampler and type_path != "test":
if self.hparams.sortish_sampler and type_path != "test" and type_path != "val":
sampler = dataset.make_sortish_sampler(batch_size, distributed=self.hparams.gpus > 1)
return DataLoader(
dataset,
@@ -263,7 +262,7 @@ class SummarizationModule(BaseTransformer):
sampler=sampler,
)
elif self.hparams.max_tokens_per_batch is not None and type_path != "test":
elif self.hparams.max_tokens_per_batch is not None and type_path != "test" and type_path != "val":
batch_sampler = dataset.make_dynamic_sampler(
self.hparams.max_tokens_per_batch, distributed=self.hparams.gpus > 1
)