Add support for ZeRO-2/3 and ZeRO-offload in fairscale (#10354)
* Add support for ZeRO-2/3 and ZeRO-offload in fairscale * Quality * Rework from review comments * Add doc * Apply suggestions from code review Co-authored-by: Stas Bekman <stas00@users.noreply.github.com> * Address review comments Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
This commit is contained in:
@@ -25,7 +25,7 @@ from .file_utils import (
|
||||
is_torch_tpu_available,
|
||||
torch_required,
|
||||
)
|
||||
from .trainer_utils import EvaluationStrategy, LoggingStrategy, SchedulerType
|
||||
from .trainer_utils import EvaluationStrategy, LoggingStrategy, SchedulerType, ShardedDDPOption
|
||||
from .utils import logging
|
||||
|
||||
|
||||
@@ -236,9 +236,22 @@ class TrainingArguments:
|
||||
When resuming training, whether or not to skip the epochs and batches to get the data loading at the same
|
||||
stage as in the previous training. If set to :obj:`True`, the training will begin faster (as that skipping
|
||||
step can take a long time) but will not yield the same results as the interrupted training would have.
|
||||
sharded_ddp (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
sharded_ddp (:obj:`bool`, :obj:`str` or list of :class:`~transformers.trainer_utils.ShardedDDPOption`, `optional`, defaults to :obj:`False`):
|
||||
Use Sharded DDP training from `FairScale <https://github.com/facebookresearch/fairscale>`__ (in distributed
|
||||
training only). This is an experimental feature.
|
||||
|
||||
A list of options among the following:
|
||||
|
||||
- :obj:`"simple"`: to use the first instance of sharded DDP released by fairscale (:obj:`ShardedDDP`) similar
|
||||
to ZeRO-2.
|
||||
- :obj:`"zero_dp_2"`: to use the second instance of sharded DDP released by fairscale
|
||||
(:obj:`FullyShardedDDP`) in Zero-2 mode (with :obj:`reshard_after_forward=False`).
|
||||
- :obj:`"zero_dp_3"`: to use the second instance of sharded DDP released by fairscale
|
||||
(:obj:`FullyShardedDDP`) in Zero-3 mode (with :obj:`reshard_after_forward=True`).
|
||||
- :obj:`"offload"`: to add ZeRO-offload (only compatible with :obj:`"zero_dp_2"` and :obj:`"zero_dp_3"`).
|
||||
|
||||
If a string is passed, it will be split on space. If a bool is passed, it will be converted to an empty
|
||||
list for :obj:`False` and :obj:`["simple"]` for :obj:`True`.
|
||||
deepspeed (:obj:`str`, `optional`):
|
||||
Use `Deepspeed <https://github.com/microsoft/deepspeed>`__. This is an experimental feature and its API may
|
||||
evolve in the future. The value is the location of its json config file (usually ``ds_config.json``).
|
||||
@@ -443,9 +456,14 @@ class TrainingArguments:
|
||||
"help": "When resuming training, whether or not to skip the first epochs and batches to get to the same training data."
|
||||
},
|
||||
)
|
||||
sharded_ddp: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether or not to use sharded DDP training (in distributed training only)."},
|
||||
sharded_ddp: str = field(
|
||||
default="",
|
||||
metadata={
|
||||
"choices": ["simple", "zero_dp_2", "zero_dp_3", "zero_dp_2 offload", "zero_dp_3 offload"],
|
||||
"help": "Whether or not to use sharded DDP training (in distributed training only). The base option "
|
||||
"should be `simple`, `zero_dp_2` or `zero_dp_3` and you can add CPU-offload to `zero_dp_2` or `zero_dp_3` "
|
||||
"like this: zero_dp_2 offload` or `zero_dp_3 offload`",
|
||||
},
|
||||
)
|
||||
deepspeed: Optional[str] = field(
|
||||
default=None,
|
||||
@@ -535,6 +553,20 @@ class TrainingArguments:
|
||||
"Both warmup_ratio and warmup_steps given, warmup_steps will override any effect of warmup_ratio during training"
|
||||
)
|
||||
|
||||
if isinstance(self.sharded_ddp, bool):
|
||||
self.sharded_ddp = "simple" if self.sharded_ddp else ""
|
||||
if isinstance(self.sharded_ddp, str):
|
||||
self.sharded_ddp = [ShardedDDPOption(s) for s in self.sharded_ddp.split()]
|
||||
if self.sharded_ddp == [ShardedDDPOption.OFFLOAD]:
|
||||
raise ValueError(
|
||||
"`--sharded_ddp offload` can't work on its own. It needs to be added to `--sharded_ddp zero_dp_2` or "
|
||||
'`--sharded_ddp zero_dp_3`. For example, `--sharded_ddp "zero_dp_2 offload"`.'
|
||||
)
|
||||
elif len(self.sharded_ddp) > 1 and ShardedDDPOption.Simple in self.sharded_ddp:
|
||||
raise ValueError("`--sharded_ddp simple` is not compatible with any other option.")
|
||||
elif ShardedDDPOption.ZERO_DP_2 in self.sharded_ddp and ShardedDDPOption.ZERO_DP_3 in self.sharded_ddp:
|
||||
raise ValueError("`--sharded_ddp zero_dp_2` is not compatible with `--sharded_ddp zero_dp_3`.")
|
||||
|
||||
def __repr__(self):
|
||||
# We override the default repr to remove deprecated arguments from the repr. This method should be removed once
|
||||
# those deprecated arguments are removed from TrainingArguments. (TODO: v5)
|
||||
@@ -662,7 +694,7 @@ class TrainingArguments:
|
||||
|
||||
- :obj:`ParallelMode.NOT_PARALLEL`: no parallelism (CPU or one GPU).
|
||||
- :obj:`ParallelMode.NOT_DISTRIBUTED`: several GPUs in one single process (uses :obj:`torch.nn.DataParallel`).
|
||||
- :obj:`ParallelMode.DISTRIBUTED`: several GPUs, each ahving its own process (uses
|
||||
- :obj:`ParallelMode.DISTRIBUTED`: several GPUs, each having its own process (uses
|
||||
:obj:`torch.nn.DistributedDataParallel`).
|
||||
- :obj:`ParallelMode.TPU`: several TPU cores.
|
||||
"""
|
||||
@@ -692,6 +724,8 @@ class TrainingArguments:
|
||||
for k, v in d.items():
|
||||
if isinstance(v, Enum):
|
||||
d[k] = v.value
|
||||
if isinstance(v, list) and len(v) > 0 and isinstance(v[0], Enum):
|
||||
d[k] = [x.value for x in v]
|
||||
return d
|
||||
|
||||
def to_json_string(self):
|
||||
|
||||
Reference in New Issue
Block a user