From c7b7bd9963945ee3b10de1c2d2243d23cb621a98 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Wed, 27 Jan 2021 06:18:06 -0500 Subject: [PATCH] Add a flag for find_unused_parameters (#9820) * Add a flag for find_unused_parameters * Apply suggestions from code review Co-authored-by: Stas Bekman * Remove negation Co-authored-by: Stas Bekman --- src/transformers/trainer.py | 16 +++++++++------- src/transformers/training_args.py | 11 +++++++++++ 2 files changed, 20 insertions(+), 7 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index e64a1e3fb7..701657fdb4 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -761,18 +761,20 @@ class Trainer: elif is_sagemaker_distributed_available(): model = DDP(model, device_ids=[dist.get_local_rank()], broadcast_buffers=False) elif self.args.local_rank != -1: + if self.args.ddp_find_unused_parameters is not None: + find_unused_parameters = self.args.ddp_find_unused_parameters + elif isinstance(model, PreTrainedModel): + # find_unused_parameters breaks checkpointing as per + # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021 + find_unused_parameters = not getattr(model.config, "gradient_checkpointing", False) + else: + find_unused_parameters = True model = torch.nn.parallel.DistributedDataParallel( model, device_ids=[self.args.local_rank], output_device=self.args.local_rank, - find_unused_parameters=( - not getattr(model.config, "gradient_checkpointing", False) - if isinstance(model, PreTrainedModel) - else True - ), + find_unused_parameters=find_unused_parameters, ) - # find_unused_parameters breaks checkpointing as per - # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021 # for the rest of this function `model` is the outside model, whether it was wrapped or not if model is not self.model: diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 50e7a00ca6..6811c61154 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -240,6 +240,10 @@ class TrainingArguments: report_to (:obj:`List[str]`, `optional`, defaults to the list of integrations platforms installed): The list of integrations to report the results and logs to. Supported platforms are :obj:`"azure_ml"`, :obj:`"comet_ml"`, :obj:`"mlflow"`, :obj:`"tensorboard"` and :obj:`"wandb"`. + ddp_find_unused_parameters (:obj:`bool`, `optional`): + When using distributed training, the value of the flag :obj:`find_unused_parameters` passed to + :obj:`DistributedDataParallel`. Will defaut to :obj:`False` if gradient checkpointing is used, :obj:`True` + otherwise. """ output_dir: str = field( @@ -425,6 +429,13 @@ class TrainingArguments: report_to: Optional[List[str]] = field( default=None, metadata={"help": "The list of integrations to report the results and logs to."} ) + ddp_find_unused_parameters: Optional[bool] = field( + default=None, + metadata={ + "help": "When using distributed training, the value of the flag `find_unused_parameters` passed to " + "`DistributedDataParallel`." + }, + ) _n_gpu: int = field(init=False, repr=False, default=-1) def __post_init__(self):