From 17099ebd58a2671d702d5e37ab0fe2cfbf9b8ee2 Mon Sep 17 00:00:00 2001 From: Chady Kamar Date: Tue, 22 Sep 2020 14:44:42 -0400 Subject: [PATCH] Add num workers cli arg (#7322) * Add dataloader_num_workers to TrainingArguments This argument is meant to be used to set the number of workers for the PyTorch DataLoader. * Pass num_workers argument on DataLoader init --- src/transformers/trainer.py | 2 ++ src/transformers/training_args.py | 6 ++++++ 2 files changed, 8 insertions(+) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index de00161050..cdfce528a8 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -352,6 +352,7 @@ class Trainer: sampler=train_sampler, collate_fn=self.data_collator, drop_last=self.args.dataloader_drop_last, + num_workers=self.args.dataloader_num_workers, ) def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.sampler.Sampler]: @@ -391,6 +392,7 @@ class Trainer: batch_size=self.args.eval_batch_size, collate_fn=self.data_collator, drop_last=self.args.dataloader_drop_last, + num_workers=self.args.dataloader_num_workers, ) def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader: diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 1000f09144..2344098b42 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -122,6 +122,8 @@ class TrainingArguments: eval_steps (:obj:`int`, `optional`): Number of update steps between two evaluations if :obj:`evaluation_strategy="steps"`. Will default to the same value as :obj:`logging_steps` if not set. + dataloader_num_workers (:obj:`int`, `optional`, defaults to 0): + Number of subprocesses to use for data loading (PyTorch only). 0 means that the data will be loaded in the main process. past_index (:obj:`int`, `optional`, defaults to -1): Some models like :doc:`TransformerXL <../model_doc/transformerxl>` or :doc`XLNet <../model_doc/xlnet>` can make use of the past hidden states for their predictions. If this argument is set to a positive int, the @@ -259,6 +261,10 @@ class TrainingArguments: default=False, metadata={"help": "Drop the last incomplete batch if it is not divisible by the batch size."} ) eval_steps: int = field(default=None, metadata={"help": "Run an evaluation every X steps."}) + dataloader_num_workers: int = field( + default=0, + metadata={"help": "Number of subprocesses to use for data loading (PyTorch only). 0 means that the data will be loaded in the main process."} + ) past_index: int = field( default=-1,