From ee88ae59940fd4b2c8fc119373143d7a1175c651 Mon Sep 17 00:00:00 2001 From: Teven Date: Fri, 16 Jun 2023 15:14:03 -0400 Subject: [PATCH] Adding ddp_broadcast_buffers argument to Trainer (#24326) adding ddp_broadcast_buffers argument --- src/transformers/trainer.py | 3 +++ src/transformers/training_args.py | 12 ++++++++++++ 2 files changed, 15 insertions(+) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index f47b57d6d9..0eec118cf4 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -1450,6 +1450,9 @@ class Trainer: if self.args.ddp_bucket_cap_mb is not None: kwargs["bucket_cap_mb"] = self.args.ddp_bucket_cap_mb + if self.args.ddp_broadcast_buffers is not None: + kwargs["broadcast_buffers"] = self.args.ddp_broadcast_buffers + self.accelerator.ddp_handler = DistributedDataParallelKwargs(**kwargs) return model diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 5fe680a802..71d75c6e0e 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -505,6 +505,9 @@ class TrainingArguments: `DistributedDataParallel`. Will default to `False` if gradient checkpointing is used, `True` otherwise. ddp_bucket_cap_mb (`int`, *optional*): When using distributed training, the value of the flag `bucket_cap_mb` passed to `DistributedDataParallel`. + ddp_broadcast_buffers (`bool`, *optional*): + When using distributed training, the value of the flag `broadcast_buffers` passed to + `DistributedDataParallel`. Will default to `False` if gradient checkpointing is used, `True` otherwise. dataloader_pin_memory (`bool`, *optional*, defaults to `True`): Whether you want to pin memory in data loaders or not. Will default to `True`. skip_memory_metrics (`bool`, *optional*, defaults to `True`): @@ -1045,6 +1048,15 @@ class TrainingArguments: ) }, ) + ddp_broadcast_buffers: Optional[bool] = field( + default=None, + metadata={ + "help": ( + "When using distributed training, the value of the flag `broadcast_buffers` passed to " + "`DistributedDataParallel`." + ) + }, + ) dataloader_pin_memory: bool = field( default=True, metadata={"help": "Whether or not to pin memory for DataLoader."} )