[seq2seq testing] multigpu test run via subprocess (#7281)
Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
This commit is contained in:
@@ -170,7 +170,7 @@ class BaseTransformer(pl.LightningModule):
|
||||
self.dataset_size = len(self.test_dataloader().dataset)
|
||||
else:
|
||||
self.train_loader = self.get_dataloader("train", self.hparams.train_batch_size, shuffle=True)
|
||||
self.dataset_size = len(self.train_loader.dataset)
|
||||
self.dataset_size = len(self.train_dataloader().dataset)
|
||||
|
||||
def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False):
|
||||
raise NotImplementedError("You must implement this for your task")
|
||||
|
||||
Reference in New Issue
Block a user