Add a flag for find_unused_parameters (#9820)
* Add a flag for find_unused_parameters
* Apply suggestions from code review
  Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
* Remove negation
  Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
This commit is contained in:
@@ -761,18 +761,20 @@ class Trainer:
|
|||||||
elif is_sagemaker_distributed_available():
|
elif is_sagemaker_distributed_available():
|
||||||
model = DDP(model, device_ids=[dist.get_local_rank()], broadcast_buffers=False)
|
model = DDP(model, device_ids=[dist.get_local_rank()], broadcast_buffers=False)
|
||||||
elif self.args.local_rank != -1:
|
elif self.args.local_rank != -1:
|
||||||
|
if self.args.ddp_find_unused_parameters is not None:
|
||||||
|
find_unused_parameters = self.args.ddp_find_unused_parameters
|
||||||
|
elif isinstance(model, PreTrainedModel):
|
||||||
|
# find_unused_parameters breaks checkpointing as per
|
||||||
|
# https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021
|
||||||
|
find_unused_parameters = not getattr(model.config, "gradient_checkpointing", False)
|
||||||
|
else:
|
||||||
|
find_unused_parameters = True
|
||||||
model = torch.nn.parallel.DistributedDataParallel(
|
model = torch.nn.parallel.DistributedDataParallel(
|
||||||
model,
|
model,
|
||||||
device_ids=[self.args.local_rank],
|
device_ids=[self.args.local_rank],
|
||||||
output_device=self.args.local_rank,
|
output_device=self.args.local_rank,
|
||||||
find_unused_parameters=(
|
find_unused_parameters=find_unused_parameters,
|
||||||
not getattr(model.config, "gradient_checkpointing", False)
|
|
||||||
if isinstance(model, PreTrainedModel)
|
|
||||||
else True
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
# find_unused_parameters breaks checkpointing as per
|
|
||||||
# https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021
|
|
||||||
|
|
||||||
# for the rest of this function `model` is the outside model, whether it was wrapped or not
|
# for the rest of this function `model` is the outside model, whether it was wrapped or not
|
||||||
if model is not self.model:
|
if model is not self.model:
|
||||||
|
|||||||
@@ -240,6 +240,10 @@ class TrainingArguments:
|
|||||||
report_to (:obj:`List[str]`, `optional`, defaults to the list of integrations platforms installed):
|
report_to (:obj:`List[str]`, `optional`, defaults to the list of integrations platforms installed):
|
||||||
The list of integrations to report the results and logs to. Supported platforms are :obj:`"azure_ml"`,
|
The list of integrations to report the results and logs to. Supported platforms are :obj:`"azure_ml"`,
|
||||||
:obj:`"comet_ml"`, :obj:`"mlflow"`, :obj:`"tensorboard"` and :obj:`"wandb"`.
|
:obj:`"comet_ml"`, :obj:`"mlflow"`, :obj:`"tensorboard"` and :obj:`"wandb"`.
|
||||||
|
ddp_find_unused_parameters (:obj:`bool`, `optional`):
|
||||||
|
When using distributed training, the value of the flag :obj:`find_unused_parameters` passed to
|
||||||
|
:obj:`DistributedDataParallel`. Will default to :obj:`False` if gradient checkpointing is used, :obj:`True`
|
||||||
|
otherwise.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
output_dir: str = field(
|
output_dir: str = field(
|
||||||
@@ -425,6 +429,13 @@ class TrainingArguments:
|
|||||||
report_to: Optional[List[str]] = field(
|
report_to: Optional[List[str]] = field(
|
||||||
default=None, metadata={"help": "The list of integrations to report the results and logs to."}
|
default=None, metadata={"help": "The list of integrations to report the results and logs to."}
|
||||||
)
|
)
|
||||||
|
ddp_find_unused_parameters: Optional[bool] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={
|
||||||
|
"help": "When using distributed training, the value of the flag `find_unused_parameters` passed to "
|
||||||
|
"`DistributedDataParallel`."
|
||||||
|
},
|
||||||
|
)
|
||||||
_n_gpu: int = field(init=False, repr=False, default=-1)
|
_n_gpu: int = field(init=False, repr=False, default=-1)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user