Adding ddp_broadcast_buffers argument to Trainer (#24326)
adding ddp_broadcast_buffers argument
This commit is contained in:
@@ -1450,6 +1450,9 @@ class Trainer:
|
|||||||
if self.args.ddp_bucket_cap_mb is not None:
|
if self.args.ddp_bucket_cap_mb is not None:
|
||||||
kwargs["bucket_cap_mb"] = self.args.ddp_bucket_cap_mb
|
kwargs["bucket_cap_mb"] = self.args.ddp_bucket_cap_mb
|
||||||
|
|
||||||
|
if self.args.ddp_broadcast_buffers is not None:
|
||||||
|
kwargs["broadcast_buffers"] = self.args.ddp_broadcast_buffers
|
||||||
|
|
||||||
self.accelerator.ddp_handler = DistributedDataParallelKwargs(**kwargs)
|
self.accelerator.ddp_handler = DistributedDataParallelKwargs(**kwargs)
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|||||||
@@ -505,6 +505,9 @@ class TrainingArguments:
|
|||||||
`DistributedDataParallel`. Will default to `False` if gradient checkpointing is used, `True` otherwise.
|
`DistributedDataParallel`. Will default to `False` if gradient checkpointing is used, `True` otherwise.
|
||||||
ddp_bucket_cap_mb (`int`, *optional*):
|
ddp_bucket_cap_mb (`int`, *optional*):
|
||||||
When using distributed training, the value of the flag `bucket_cap_mb` passed to `DistributedDataParallel`.
|
When using distributed training, the value of the flag `bucket_cap_mb` passed to `DistributedDataParallel`.
|
||||||
|
ddp_broadcast_buffers (`bool`, *optional*):
|
||||||
|
When using distributed training, the value of the flag `broadcast_buffers` passed to
|
||||||
|
`DistributedDataParallel`. Will default to `False` if gradient checkpointing is used, `True` otherwise.
|
||||||
dataloader_pin_memory (`bool`, *optional*, defaults to `True`):
|
dataloader_pin_memory (`bool`, *optional*, defaults to `True`):
|
||||||
Whether you want to pin memory in data loaders or not. Will default to `True`.
|
Whether you want to pin memory in data loaders or not. Will default to `True`.
|
||||||
skip_memory_metrics (`bool`, *optional*, defaults to `True`):
|
skip_memory_metrics (`bool`, *optional*, defaults to `True`):
|
||||||
@@ -1045,6 +1048,15 @@ class TrainingArguments:
|
|||||||
)
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
ddp_broadcast_buffers: Optional[bool] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={
|
||||||
|
"help": (
|
||||||
|
"When using distributed training, the value of the flag `broadcast_buffers` passed to "
|
||||||
|
"`DistributedDataParallel`."
|
||||||
|
)
|
||||||
|
},
|
||||||
|
)
|
||||||
dataloader_pin_memory: bool = field(
|
dataloader_pin_memory: bool = field(
|
||||||
default=True, metadata={"help": "Whether or not to pin memory for DataLoader."}
|
default=True, metadata={"help": "Whether or not to pin memory for DataLoader."}
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user