Add support for ZeRO-2/3 and ZeRO-offload in fairscale (#10354)

* Add support for ZeRO-2/3 and ZeRO-offload in fairscale

* Quality

* Rework from review comments

* Add doc

* Apply suggestions from code review

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

* Address review comments

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
This commit is contained in:
Sylvain Gugger
2021-02-25 11:07:53 -05:00
committed by GitHub
parent 88cc26dcd1
commit 9d14be5c20
5 changed files with 193 additions and 46 deletions

View File

@@ -93,6 +93,7 @@ from .trainer_utils import (
EvalPrediction,
HPSearchBackend,
PredictionOutput,
ShardedDDPOption,
TrainerMemoryTracker,
TrainOutput,
default_compute_objective,
@@ -131,10 +132,16 @@ if is_torch_tpu_available():
import torch_xla.distributed.parallel_loader as pl
if is_fairscale_available():
import fairscale
from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
from fairscale.optim import OSS
from fairscale.optim.grad_scaler import ShardedGradScaler
if version.parse(fairscale.__version__) >= version.parse("0.3"):
from fairscale.nn.data_parallel import FullyShardedDataParallel as FullyShardedDDP
else:
FullyShardedDDP = None
if is_sagemaker_distributed_available():
import smdistributed.dataparallel.torch.distributed as dist
from smdistributed.dataparallel.torch.parallel.distributed import DistributedDataParallel as DDP
@@ -277,9 +284,38 @@ class Trainer:
else:
self.is_model_parallel = False
# Setup Sharded DDP training
self.sharded_ddp = None
if len(args.sharded_ddp) > 0:
if args.deepspeed:
raise ValueError(
"Using --sharded_ddp xxx together with --deepspeed is not possible, deactivate one of those flags."
)
if args.local_rank == -1:
raise ValueError("Using sharded DDP only works in distributed training.")
elif not is_fairscale_available():
raise ImportError("Sharded DDP training requires fairscale: `pip install fairscale`.")
elif ShardedDDPOption.SIMPLE not in args.sharded_ddp and FullyShardedDDP is None:
raise ImportError(
"Sharded DDP in a mode other than simple training requires fairscale version >= 0.3, found "
f"{fairscale.__version__}. Upgrade your fairscale library: `pip install --upgrade fairscale`."
)
elif ShardedDDPOption.SIMPLE in args.sharded_ddp:
self.sharded_ddp = ShardedDDPOption.SIMPLE
elif ShardedDDPOption.ZERO_DP_2 in args.sharded_ddp:
self.sharded_ddp = ShardedDDPOption.ZERO_DP_2
elif ShardedDDPOption.ZERO_DP_3 in args.sharded_ddp:
self.sharded_ddp = ShardedDDPOption.ZERO_DP_3
# one place to sort out whether to place the model on device or not
self.place_model_on_device = args.place_model_on_device
if self.is_model_parallel or (args.deepspeed and args.do_train) or (args.fp16_full_eval and not args.do_train):
if (
self.is_model_parallel
or (args.deepspeed and args.do_train)
or (args.fp16_full_eval and not args.do_train)
or (self.sharded_ddp in [ShardedDDPOption.ZERO_DP_2, ShardedDDPOption.ZERO_DP_3])
):
self.place_model_on_device = False
default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer)
@@ -346,21 +382,6 @@ class Trainer:
if isinstance(eval_dataset, datasets.Dataset):
self._remove_unused_columns(self.eval_dataset, description="evaluation")
# Setup Sharded DDP training
self.sharded_dpp = False
if args.sharded_ddp:
if args.deepspeed:
raise ValueError(
"Using --sharded_ddp together with --deepspeed is not possible, deactivate one of those flags."
)
if args.local_rank == -1:
raise ValueError("Using sharded DDP only works in distributed training.")
elif not is_fairscale_available():
raise ImportError("Sharded DDP training requires fairscale: `pip install fairscale`.")
else:
self.sharded_dpp = True
# Mixed precision setup
self.use_apex = False
self.use_amp = False
@@ -376,7 +397,7 @@ class Trainer:
if args.fp16 and not args.deepspeed: # deepspeed manages its own fp16
if self.fp16_backend == "amp":
self.use_amp = True
self.scaler = ShardedGradScaler() if self.sharded_dpp else torch.cuda.amp.GradScaler()
self.scaler = ShardedGradScaler() if self.sharded_ddp is not None else torch.cuda.amp.GradScaler()
else:
if not is_apex_available():
raise ImportError(
@@ -619,7 +640,7 @@ class Trainer:
"eps": self.args.adam_epsilon,
}
optimizer_kwargs["lr"] = self.args.learning_rate
if self.sharded_dpp:
if self.sharded_ddp == ShardedDDPOption.SIMPLE:
self.optimizer = OSS(
params=optimizer_grouped_parameters,
optim=optimizer_cls,
@@ -737,8 +758,19 @@ class Trainer:
return model
# Distributed training (should be after apex fp16 initialization)
if self.sharded_dpp:
model = ShardedDDP(model, self.optimizer)
if self.sharded_ddp is not None:
# Sharded DDP!
if self.sharded_ddp == ShardedDDPOption.SIMPLE:
model = ShardedDDP(model, self.optimizer)
else:
mixed_precision = self.args.fp16
cpu_offload = ShardedDDPOption.OFFLOAD in self.args.sharded_ddp
zero_3 = self.sharded_ddp == ShardedDDPOption.ZERO_DP_3
# XXX: Breaking the self.model convention but I see no way around it for now.
self.model = model = FullyShardedDDP(
model, mixed_precision=mixed_precision, reshard_after_forward=zero_3, cpu_offload=cpu_offload
).to(self.args.device)
elif is_sagemaker_distributed_available():
model = DDP(model, device_ids=[dist.get_local_rank()], broadcast_buffers=False)
elif self.args.local_rank != -1:
@@ -855,6 +887,7 @@ class Trainer:
num_train_epochs = 1
num_update_steps_per_epoch = max_steps
delay_optimizer_creation = self.sharded_ddp is not None and self.sharded_ddp != ShardedDDPOption.SIMPLE
if self.args.deepspeed:
model, optimizer, lr_scheduler = init_deepspeed(self, num_training_steps=max_steps)
self.model = model.module
@@ -862,7 +895,7 @@ class Trainer:
self.deepspeed = model # DeepSpeedEngine object
self.optimizer = optimizer
self.lr_scheduler = lr_scheduler
else:
elif not delay_optimizer_creation:
self.create_optimizer_and_scheduler(num_training_steps=max_steps)
self.state = TrainerState()
@@ -877,6 +910,9 @@ class Trainer:
if model is not self.model:
self.model_wrapped = model
if delay_optimizer_creation:
self.create_optimizer_and_scheduler(num_training_steps=max_steps)
# important: at this point:
# self.model is the Transformers Model
# self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model), etc.
@@ -1026,6 +1062,9 @@ class Trainer:
if hasattr(self.optimizer, "clip_grad_norm"):
# Some optimizers (like the sharded optimizer) have a specific way to do gradient clipping
self.optimizer.clip_grad_norm(self.args.max_grad_norm)
elif hasattr(model, "clip_grad_norm_"):
# Some models (like FullyShardedDDP) have a specific way to do gradient clipping
model.clip_grad_norm_(self.args.max_grad_norm)
else:
# Revert to normal clipping otherwise, handling Apex or full precision
torch.nn.utils.clip_grad_norm_(
@@ -1148,8 +1187,8 @@ class Trainer:
def _save_checkpoint(self, model, trial, metrics=None):
# In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we
# want to save.
assert _model_unwrap(model) is self.model, "internal model should be a reference to self.model"
# want to save except FullyShardedDDP.
# assert _model_unwrap(model) is self.model, "internal model should be a reference to self.model"
# Save model checkpoint
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
@@ -1173,7 +1212,7 @@ class Trainer:
self.deepspeed.save_checkpoint(output_dir)
# Save optimizer and scheduler
if self.sharded_dpp:
if self.sharded_ddp == ShardedDDPOption.SIMPLE:
self.optimizer.consolidate_state_dict()
if is_torch_tpu_available():
@@ -1479,7 +1518,11 @@ class Trainer:
# They can then be reloaded using `from_pretrained()`
xm.rendezvous("saving_checkpoint")
if not isinstance(self.model, PreTrainedModel):
logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
if isinstance(_model_unwrap(self.model), PreTrainedModel):
if xm.is_master_ordinal():
_model_unwrap(self.model).config.save_pretrained(output_dir)
else:
logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
state_dict = self.model.state_dict()
xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
else:
@@ -1494,7 +1537,10 @@ class Trainer:
# Save a trained model and configuration using `save_pretrained()`.
# They can then be reloaded using `from_pretrained()`
if not isinstance(self.model, PreTrainedModel):
logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
if isinstance(_model_unwrap(self.model), PreTrainedModel):
_model_unwrap(self.model).config.save_pretrained(output_dir)
else:
logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
state_dict = self.model.state_dict()
torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
else:

View File

@@ -421,3 +421,10 @@ class TrainerMemoryTracker:
# init doesn't have metrics to update so we just save that data for later stages to retrieve
if metrics is not None:
self.update_metrics(stage, metrics)
class ShardedDDPOption(ExplicitEnum):
    """
    Options accepted by the ``sharded_ddp`` training argument.

    Each value is the exact token a user passes (e.g. ``--sharded_ddp "zero_dp_2 offload"``);
    ``TrainingArguments`` builds members by value with ``ShardedDDPOption(token)``, so the values
    must match the documented/CLI spellings exactly.
    """

    # Wrap the model with fairscale's `ShardedDDP` (optimizer state sharded through `OSS`).
    SIMPLE = "simple"
    # Wrap the model with fairscale's `FullyShardedDDP`, with `reshard_after_forward=False`.
    ZERO_DP_2 = "zero_dp_2"
    # Wrap the model with fairscale's `FullyShardedDDP`, with `reshard_after_forward=True`.
    ZERO_DP_3 = "zero_dp_3"
    # Modifier: pass `cpu_offload=True` to `FullyShardedDDP`; not a standalone mode.
    OFFLOAD = "offload"

View File

@@ -25,7 +25,7 @@ from .file_utils import (
is_torch_tpu_available,
torch_required,
)
from .trainer_utils import EvaluationStrategy, LoggingStrategy, SchedulerType
from .trainer_utils import EvaluationStrategy, LoggingStrategy, SchedulerType, ShardedDDPOption
from .utils import logging
@@ -236,9 +236,22 @@ class TrainingArguments:
When resuming training, whether or not to skip the epochs and batches to get the data loading at the same
stage as in the previous training. If set to :obj:`True`, the training will begin faster (as that skipping
step can take a long time) but will not yield the same results as the interrupted training would have.
sharded_ddp (:obj:`bool`, `optional`, defaults to :obj:`False`):
sharded_ddp (:obj:`bool`, :obj:`str` or list of :class:`~transformers.trainer_utils.ShardedDDPOption`, `optional`, defaults to :obj:`False`):
Use Sharded DDP training from `FairScale <https://github.com/facebookresearch/fairscale>`__ (in distributed
training only). This is an experimental feature.
A list of options among the following:
- :obj:`"simple"`: to use the first instance of sharded DDP released by fairscale (:obj:`ShardedDDP`) similar
to ZeRO-2.
- :obj:`"zero_dp_2"`: to use the second instance of sharded DDP released by fairscale
(:obj:`FullyShardedDDP`) in Zero-2 mode (with :obj:`reshard_after_forward=False`).
- :obj:`"zero_dp_3"`: to use the second instance of sharded DDP released by fairscale
(:obj:`FullyShardedDDP`) in Zero-3 mode (with :obj:`reshard_after_forward=True`).
- :obj:`"offload"`: to add ZeRO-offload (only compatible with :obj:`"zero_dp_2"` and :obj:`"zero_dp_3"`).
If a string is passed, it will be split on space. If a bool is passed, it will be converted to an empty
list for :obj:`False` and :obj:`["simple"]` for :obj:`True`.
deepspeed (:obj:`str`, `optional`):
Use `Deepspeed <https://github.com/microsoft/deepspeed>`__. This is an experimental feature and its API may
evolve in the future. The value is the location of its json config file (usually ``ds_config.json``).
@@ -443,9 +456,14 @@ class TrainingArguments:
"help": "When resuming training, whether or not to skip the first epochs and batches to get to the same training data."
},
)
sharded_ddp: bool = field(
default=False,
metadata={"help": "Whether or not to use sharded DDP training (in distributed training only)."},
sharded_ddp: str = field(
default="",
metadata={
"choices": ["simple", "zero_dp_2", "zero_dp_3", "zero_dp_2 offload", "zero_dp_3 offload"],
"help": "Whether or not to use sharded DDP training (in distributed training only). The base option "
"should be `simple`, `zero_dp_2` or `zero_dp_3` and you can add CPU-offload to `zero_dp_2` or `zero_dp_3` "
"like this: zero_dp_2 offload` or `zero_dp_3 offload`",
},
)
deepspeed: Optional[str] = field(
default=None,
@@ -535,6 +553,20 @@ class TrainingArguments:
"Both warmup_ratio and warmup_steps given, warmup_steps will override any effect of warmup_ratio during training"
)
# Normalize `sharded_ddp` to a list of `ShardedDDPOption` members: a bool is first mapped to its
# string equivalent ("simple" for True, "" for False), then the string is split on whitespace and
# each token is looked up by value (an unknown token makes the enum lookup raise).
if isinstance(self.sharded_ddp, bool):
    self.sharded_ddp = "simple" if self.sharded_ddp else ""
if isinstance(self.sharded_ddp, str):
    self.sharded_ddp = [ShardedDDPOption(s) for s in self.sharded_ddp.split()]
# Reject combinations of options that cannot work together.
if self.sharded_ddp == [ShardedDDPOption.OFFLOAD]:
    raise ValueError(
        "`--sharded_ddp offload` can't work on its own. It needs to be added to `--sharded_ddp zero_dp_2` or "
        '`--sharded_ddp zero_dp_3`. For example, `--sharded_ddp "zero_dp_2 offload"`.'
    )
# Enum member names are case-sensitive: the member is `SIMPLE`. Any other spelling raises
# AttributeError as soon as more than one option is passed (e.g. "zero_dp_2 offload").
elif len(self.sharded_ddp) > 1 and ShardedDDPOption.SIMPLE in self.sharded_ddp:
    raise ValueError("`--sharded_ddp simple` is not compatible with any other option.")
elif ShardedDDPOption.ZERO_DP_2 in self.sharded_ddp and ShardedDDPOption.ZERO_DP_3 in self.sharded_ddp:
    raise ValueError("`--sharded_ddp zero_dp_2` is not compatible with `--sharded_ddp zero_dp_3`.")
def __repr__(self):
# We override the default repr to remove deprecated arguments from the repr. This method should be removed once
# those deprecated arguments are removed form TrainingArguments. (TODO: v5)
@@ -662,7 +694,7 @@ class TrainingArguments:
- :obj:`ParallelMode.NOT_PARALLEL`: no parallelism (CPU or one GPU).
- :obj:`ParallelMode.NOT_DISTRIBUTED`: several GPUs in one single process (uses :obj:`torch.nn.DataParallel`).
- :obj:`ParallelMode.DISTRIBUTED`: several GPUs, each ahving its own process (uses
- :obj:`ParallelMode.DISTRIBUTED`: several GPUs, each having its own process (uses
:obj:`torch.nn.DistributedDataParallel`).
- :obj:`ParallelMode.TPU`: several TPU cores.
"""
@@ -692,6 +724,8 @@ class TrainingArguments:
for k, v in d.items():
if isinstance(v, Enum):
d[k] = v.value
if isinstance(v, list) and len(v) > 0 and isinstance(v[0], Enum):
d[k] = [x.value for x in v]
return d
def to_json_string(self):