Pin memory in Trainer by default (#9857)
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
This commit is contained in:
@@ -485,6 +485,7 @@ class Trainer:
|
|||||||
collate_fn=self.data_collator,
|
collate_fn=self.data_collator,
|
||||||
drop_last=self.args.dataloader_drop_last,
|
drop_last=self.args.dataloader_drop_last,
|
||||||
num_workers=self.args.dataloader_num_workers,
|
num_workers=self.args.dataloader_num_workers,
|
||||||
|
pin_memory=self.args.pin_memory,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.sampler.Sampler]:
|
def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.sampler.Sampler]:
|
||||||
@@ -522,6 +523,7 @@ class Trainer:
|
|||||||
collate_fn=self.data_collator,
|
collate_fn=self.data_collator,
|
||||||
drop_last=self.args.dataloader_drop_last,
|
drop_last=self.args.dataloader_drop_last,
|
||||||
num_workers=self.args.dataloader_num_workers,
|
num_workers=self.args.dataloader_num_workers,
|
||||||
|
pin_memory=self.args.pin_memory,
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
|
def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
|
||||||
@@ -548,6 +550,7 @@ class Trainer:
|
|||||||
batch_size=self.args.eval_batch_size,
|
batch_size=self.args.eval_batch_size,
|
||||||
collate_fn=self.data_collator,
|
collate_fn=self.data_collator,
|
||||||
drop_last=self.args.dataloader_drop_last,
|
drop_last=self.args.dataloader_drop_last,
|
||||||
|
pin_memory=self.args.pin_memory,
|
||||||
)
|
)
|
||||||
|
|
||||||
def create_optimizer_and_scheduler(self, num_training_steps: int):
|
def create_optimizer_and_scheduler(self, num_training_steps: int):
|
||||||
@@ -1140,7 +1143,7 @@ class Trainer:
|
|||||||
direction: str = "minimize",
|
direction: str = "minimize",
|
||||||
backend: Optional[Union["str", HPSearchBackend]] = None,
|
backend: Optional[Union["str", HPSearchBackend]] = None,
|
||||||
hp_name: Optional[Callable[["optuna.Trial"], str]] = None,
|
hp_name: Optional[Callable[["optuna.Trial"], str]] = None,
|
||||||
**kwargs
|
**kwargs,
|
||||||
) -> BestRun:
|
) -> BestRun:
|
||||||
"""
|
"""
|
||||||
Launch an hyperparameter search using ``optuna`` or ``Ray Tune``. The optimized quantity is determined by
|
Launch an hyperparameter search using ``optuna`` or ``Ray Tune``. The optimized quantity is determined by
|
||||||
|
|||||||
@@ -242,8 +242,10 @@ class TrainingArguments:
|
|||||||
:obj:`"comet_ml"`, :obj:`"mlflow"`, :obj:`"tensorboard"` and :obj:`"wandb"`.
|
:obj:`"comet_ml"`, :obj:`"mlflow"`, :obj:`"tensorboard"` and :obj:`"wandb"`.
|
||||||
ddp_find_unused_parameters (:obj:`bool`, `optional`):
|
ddp_find_unused_parameters (:obj:`bool`, `optional`):
|
||||||
When using distributed training, the value of the flag :obj:`find_unused_parameters` passed to
|
When using distributed training, the value of the flag :obj:`find_unused_parameters` passed to
|
||||||
:obj:`DistributedDataParallel`. Will defaut to :obj:`False` if gradient checkpointing is used, :obj:`True`
|
:obj:`DistributedDataParallel`. Will default to :obj:`False` if gradient checkpointing is used, :obj:`True`
|
||||||
otherwise.
|
otherwise.
|
||||||
|
pin_memory (:obj:`bool`, `optional`, defaults to :obj:`True`)):
|
||||||
|
Whether you want to pin memory in data loaders or not. Will default to :obj:`True`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
output_dir: str = field(
|
output_dir: str = field(
|
||||||
@@ -436,6 +438,7 @@ class TrainingArguments:
|
|||||||
"`DistributedDataParallel`."
|
"`DistributedDataParallel`."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
pin_memory: bool = field(default=True, metadata={"help": "Whether or not to pin memory for data loaders."})
|
||||||
_n_gpu: int = field(init=False, repr=False, default=-1)
|
_n_gpu: int = field(init=False, repr=False, default=-1)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user