Add num workers cli arg (#7322)
* Add dataloader_num_workers to TrainingArguments

  This argument sets the number of workers for the PyTorch DataLoader.

* Pass num_workers argument on DataLoader init
This commit is contained in:
@@ -352,6 +352,7 @@ class Trainer:
|
|||||||
sampler=train_sampler,
|
sampler=train_sampler,
|
||||||
collate_fn=self.data_collator,
|
collate_fn=self.data_collator,
|
||||||
drop_last=self.args.dataloader_drop_last,
|
drop_last=self.args.dataloader_drop_last,
|
||||||
|
num_workers=self.args.dataloader_num_workers,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.sampler.Sampler]:
|
def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.sampler.Sampler]:
|
||||||
@@ -391,6 +392,7 @@ class Trainer:
|
|||||||
batch_size=self.args.eval_batch_size,
|
batch_size=self.args.eval_batch_size,
|
||||||
collate_fn=self.data_collator,
|
collate_fn=self.data_collator,
|
||||||
drop_last=self.args.dataloader_drop_last,
|
drop_last=self.args.dataloader_drop_last,
|
||||||
|
num_workers=self.args.dataloader_num_workers,
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
|
def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
|
||||||
|
|||||||
@@ -122,6 +122,8 @@ class TrainingArguments:
|
|||||||
eval_steps (:obj:`int`, `optional`):
|
eval_steps (:obj:`int`, `optional`):
|
||||||
Number of update steps between two evaluations if :obj:`evaluation_strategy="steps"`. Will default to the
|
Number of update steps between two evaluations if :obj:`evaluation_strategy="steps"`. Will default to the
|
||||||
same value as :obj:`logging_steps` if not set.
|
same value as :obj:`logging_steps` if not set.
|
||||||
|
dataloader_num_workers (:obj:`int`, `optional`, defaults to 0):
|
||||||
|
Number of subprocesses to use for data loading (PyTorch only). 0 means that the data will be loaded in the main process.
|
||||||
past_index (:obj:`int`, `optional`, defaults to -1):
|
past_index (:obj:`int`, `optional`, defaults to -1):
|
||||||
Some models like :doc:`TransformerXL <../model_doc/transformerxl>` or :doc:`XLNet <../model_doc/xlnet>` can
|
Some models like :doc:`TransformerXL <../model_doc/transformerxl>` or :doc:`XLNet <../model_doc/xlnet>` can
|
||||||
make use of the past hidden states for their predictions. If this argument is set to a positive int, the
|
make use of the past hidden states for their predictions. If this argument is set to a positive int, the
|
||||||
@@ -259,6 +261,10 @@ class TrainingArguments:
|
|||||||
default=False, metadata={"help": "Drop the last incomplete batch if it is not divisible by the batch size."}
|
default=False, metadata={"help": "Drop the last incomplete batch if it is not divisible by the batch size."}
|
||||||
)
|
)
|
||||||
eval_steps: int = field(default=None, metadata={"help": "Run an evaluation every X steps."})
|
eval_steps: int = field(default=None, metadata={"help": "Run an evaluation every X steps."})
|
||||||
|
dataloader_num_workers: int = field(
|
||||||
|
default=0,
|
||||||
|
metadata={"help": "Number of subprocesses to use for data loading (PyTorch only). 0 means that the data will be loaded in the main process."}
|
||||||
|
)
|
||||||
|
|
||||||
past_index: int = field(
|
past_index: int = field(
|
||||||
default=-1,
|
default=-1,
|
||||||
|
|||||||
Reference in New Issue
Block a user