Upgrade PyTorch Lightning to 1.0.2 (#7852)
Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
This commit is contained in:
@@ -182,7 +182,6 @@ class SummarizationModule(BaseTransformer):
|
||||
return self._generative_step(batch)
|
||||
|
||||
def validation_epoch_end(self, outputs, prefix="val") -> Dict:
|
||||
|
||||
self.step_count += 1
|
||||
losses = {k: torch.stack([x[k] for x in outputs]).mean() for k in self.loss_names}
|
||||
loss = losses["loss"]
|
||||
@@ -252,7 +251,7 @@ class SummarizationModule(BaseTransformer):
|
||||
def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False) -> DataLoader:
|
||||
dataset = self.get_dataset(type_path)
|
||||
|
||||
if self.hparams.sortish_sampler and type_path != "test":
|
||||
if self.hparams.sortish_sampler and type_path != "test" and type_path != "val":
|
||||
sampler = dataset.make_sortish_sampler(batch_size, distributed=self.hparams.gpus > 1)
|
||||
return DataLoader(
|
||||
dataset,
|
||||
@@ -263,7 +262,7 @@ class SummarizationModule(BaseTransformer):
|
||||
sampler=sampler,
|
||||
)
|
||||
|
||||
elif self.hparams.max_tokens_per_batch is not None and type_path != "test":
|
||||
elif self.hparams.max_tokens_per_batch is not None and type_path != "test" and type_path != "val":
|
||||
batch_sampler = dataset.make_dynamic_sampler(
|
||||
self.hparams.max_tokens_per_batch, distributed=self.hparams.gpus > 1
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user