Shuffle train subset for summarization example (#3909)

* Shuffle train subset

* Cleaner shuffle
This commit is contained in:
Cola
2020-04-24 20:55:34 +09:00
committed by GitHub
parent c53cc018de
commit b0167632ce

View File

@@ -102,13 +102,13 @@ class SummarizationTrainer(BaseTransformer):
return self.test_end(outputs)
-def get_dataloader(self, type_path: str, batch_size: int) -> DataLoader:
+def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False) -> DataLoader:
dataset = SummarizationDataset(self.tokenizer, type_path=type_path, **self.dataset_kwargs)
-dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=dataset.collate_fn)
+dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=dataset.collate_fn, shuffle=shuffle)
return dataloader
def train_dataloader(self) -> DataLoader:
-dataloader = self.get_dataloader("train", batch_size=self.hparams.train_batch_size)
+dataloader = self.get_dataloader("train", batch_size=self.hparams.train_batch_size, shuffle=True)
t_total = (
(len(dataloader.dataset) // (self.hparams.train_batch_size * max(1, self.hparams.n_gpu)))
// self.hparams.gradient_accumulation_steps