From b0167632ce815cbdd256a0c8bfff57639748ea75 Mon Sep 17 00:00:00 2001 From: Cola <43774355+Colanim@users.noreply.github.com> Date: Fri, 24 Apr 2020 20:55:34 +0900 Subject: [PATCH] Shuffle train subset for summarization example (#3909) * Shuffle train subset * Cleaner shuffle --- examples/summarization/bart/finetune.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/summarization/bart/finetune.py b/examples/summarization/bart/finetune.py index 893188e76f..b916ea7544 100644 --- a/examples/summarization/bart/finetune.py +++ b/examples/summarization/bart/finetune.py @@ -102,13 +102,13 @@ class SummarizationTrainer(BaseTransformer): return self.test_end(outputs) - def get_dataloader(self, type_path: str, batch_size: int) -> DataLoader: + def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False) -> DataLoader: dataset = SummarizationDataset(self.tokenizer, type_path=type_path, **self.dataset_kwargs) - dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=dataset.collate_fn) + dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=dataset.collate_fn, shuffle=shuffle) return dataloader def train_dataloader(self) -> DataLoader: - dataloader = self.get_dataloader("train", batch_size=self.hparams.train_batch_size) + dataloader = self.get_dataloader("train", batch_size=self.hparams.train_batch_size, shuffle=True) t_total = ( (len(dataloader.dataset) // (self.hparams.train_batch_size * max(1, self.hparams.n_gpu))) // self.hparams.gradient_accumulation_steps