Add num workers cli arg (#7322)

* Add dataloader_num_workers to TrainingArguments

This argument is meant to be used to set the
number of workers for the PyTorch DataLoader.

* Pass num_workers argument on DataLoader init
This commit is contained in:
Chady Kamar
2020-09-22 14:44:42 -04:00
committed by GitHub
parent 25b0463d0b
commit 17099ebd58
2 changed files with 8 additions and 0 deletions

View File

@@ -352,6 +352,7 @@ class Trainer:
sampler=train_sampler,
collate_fn=self.data_collator,
drop_last=self.args.dataloader_drop_last,
num_workers=self.args.dataloader_num_workers,
)
def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.sampler.Sampler]:
@@ -391,6 +392,7 @@ class Trainer:
batch_size=self.args.eval_batch_size,
collate_fn=self.data_collator,
drop_last=self.args.dataloader_drop_last,
num_workers=self.args.dataloader_num_workers,
)
def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader: