Add num workers cli arg (#7322)
* Add dataloader_num_workers to TrainingArguments This argument is meant to be used to set the number of workers for the PyTorch DataLoader. * Pass num_workers argument on DataLoader init
This commit is contained in:
@@ -352,6 +352,7 @@ class Trainer:
|
||||
sampler=train_sampler,
|
||||
collate_fn=self.data_collator,
|
||||
drop_last=self.args.dataloader_drop_last,
|
||||
num_workers=self.args.dataloader_num_workers,
|
||||
)
|
||||
|
||||
def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.sampler.Sampler]:
|
||||
@@ -391,6 +392,7 @@ class Trainer:
|
||||
batch_size=self.args.eval_batch_size,
|
||||
collate_fn=self.data_collator,
|
||||
drop_last=self.args.dataloader_drop_last,
|
||||
num_workers=self.args.dataloader_num_workers,
|
||||
)
|
||||
|
||||
def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
|
||||
|
||||
Reference in New Issue
Block a user