[seq2seq testing] multigpu test run via subprocess (#7281)

Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
This commit is contained in:
Stas Bekman
2020-10-21 14:20:53 -07:00
committed by GitHub
parent f8d3695e8c
commit 8b38173398
6 changed files with 294 additions and 15 deletions

View File

@@ -170,7 +170,7 @@ class BaseTransformer(pl.LightningModule):
self.dataset_size = len(self.test_dataloader().dataset)
else:
self.train_loader = self.get_dataloader("train", self.hparams.train_batch_size, shuffle=True)
self.dataset_size = len(self.train_loader.dataset)
self.dataset_size = len(self.train_dataloader().dataset)
def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False):
raise NotImplementedError("You must implement this for your task")