Shuffle train subset for summarization example (#3909)
* Shuffle train subset
* Cleaner shuffle
This commit is contained in:
@@ -102,13 +102,13 @@ class SummarizationTrainer(BaseTransformer):
|
||||
|
||||
return self.test_end(outputs)
|
||||
|
||||
def get_dataloader(self, type_path: str, batch_size: int) -> DataLoader:
|
||||
def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False) -> DataLoader:
|
||||
dataset = SummarizationDataset(self.tokenizer, type_path=type_path, **self.dataset_kwargs)
|
||||
dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=dataset.collate_fn)
|
||||
dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=dataset.collate_fn, shuffle=shuffle)
|
||||
return dataloader
|
||||
|
||||
def train_dataloader(self) -> DataLoader:
|
||||
dataloader = self.get_dataloader("train", batch_size=self.hparams.train_batch_size)
|
||||
dataloader = self.get_dataloader("train", batch_size=self.hparams.train_batch_size, shuffle=True)
|
||||
t_total = (
|
||||
(len(dataloader.dataset) // (self.hparams.train_batch_size * max(1, self.hparams.n_gpu)))
|
||||
// self.hparams.gradient_accumulation_steps
|
||||
|
||||
Reference in New Issue
Block a user