PyTorch FSDP integration in Trainer (#17136)
* PyTorch FSDP integration in Trainer * reformatting make style and make quality are now compliant. * Updating dependency check * Trigger CI Co-authored-by: Sylvain Gugger <Sylvain.gugger@gmail.com>
This commit is contained in:
committed by
GitHub
parent
dc3645dc9c
commit
05fc1766ff
@@ -540,6 +540,42 @@ Known caveats:
|
||||
`FullyShardedDataParallelism` of fairscale. It should be used with the option `auto_wrap` if you are not
|
||||
doing this yourself: `--sharded_ddp "zero_dp_3 auto_wrap"`.
|
||||
|
||||
### PyTorch Fully Sharded Data Parallel
|
||||
|
||||
To accelerate training huge models on larger batch sizes, we can use a fully sharded data parallel model.
|
||||
This type of data parallel paradigm enables fitting more data and larger models by sharding the optimizer states, gradients and parameters.
|
||||
To read more about it and the benefits, check out the [Fully Sharded Data Parallel blog](https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/).
|
||||
We have integrated PyTorch's latest Fully Sharded Data Parallel (FSDP) training feature.
|
||||
All you need to do is enable it through the config.
|
||||
|
||||
**Required PyTorch version for FSDP support**: PyTorch Nightly (or 1.12.0 if you read this after it has been released)
|
||||
as the model saving with FSDP activated is only available with recent fixes.
|
||||
|
||||
**Usage**:
|
||||
|
||||
- Make sure you have added the distributed launcher
|
||||
`-m torch.distributed.launch --nproc_per_node=NUMBER_OF_GPUS_YOU_HAVE` if you haven't been using it already.
|
||||
|
||||
- **Sharding Strategy**:
|
||||
- FULL_SHARD : Shards optimizer states + gradients + model parameters across data parallel workers/GPUs.
|
||||
For this, add `--fsdp full_shard` to the command line arguments.
|
||||
- SHARD_GRAD_OP : Shards optimizer states + gradients across data parallel workers/GPUs.
|
||||
For this, add `--fsdp shard_grad_op` to the command line arguments.
|
||||
- To offload the parameters and gradients to the CPU,
|
||||
add `--fsdp "full_shard offload"` or `--fsdp "shard_grad_op offload"` to the command line arguments.
|
||||
- To automatically recursively wrap layers with FSDP using `default_auto_wrap_policy`,
|
||||
add `--fsdp "full_shard auto_wrap"` or `--fsdp "shard_grad_op auto_wrap"` to the command line arguments.
|
||||
- To enable both CPU offloading and auto wrapping,
|
||||
add `--fsdp "full_shard offload auto_wrap"` or `--fsdp "shard_grad_op offload auto_wrap"` to the command line arguments.
|
||||
- If auto wrapping is enabled, please add `--fsdp_min_num_params <number>` to command line arguments.
|
||||
It specifies FSDP's minimum number of parameters for Default Auto Wrapping.
|
||||
|
||||
**A few caveats to be aware of**
|
||||
- Mixed precision is currently not supported with FSDP as we wait for PyTorch to fix support for it.
|
||||
More details in this [issue](https://github.com/pytorch/pytorch/issues/75676).
|
||||
- FSDP currently doesn't support multiple parameter groups.
|
||||
More details mentioned in this [issue](https://github.com/pytorch/pytorch/issues/76501)
|
||||
(`The original model parameters' .grads are not set, meaning that they cannot be optimized separately (which is why we cannot support multiple parameter groups)`).
|
||||
|
||||
Sections that were moved:
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@ The Trainer class, to easily train a 🤗 Transformers from scratch or finetune
|
||||
"""
|
||||
|
||||
import contextlib
|
||||
import functools
|
||||
import inspect
|
||||
import math
|
||||
import os
|
||||
@@ -103,6 +104,7 @@ from .trainer_utils import (
|
||||
BestRun,
|
||||
EvalLoopOutput,
|
||||
EvalPrediction,
|
||||
FSDPOption,
|
||||
HPSearchBackend,
|
||||
HubStrategy,
|
||||
IntervalStrategy,
|
||||
@@ -340,6 +342,10 @@ class Trainer:
|
||||
raise ValueError(
|
||||
"Using --sharded_ddp xxx together with --deepspeed is not possible, deactivate one of those flags."
|
||||
)
|
||||
if len(args.fsdp) > 0:
|
||||
raise ValueError(
|
||||
"Using --sharded_ddp xxx together with --fsdp is not possible, deactivate one of those flags."
|
||||
)
|
||||
|
||||
if args.local_rank == -1:
|
||||
raise ValueError("Using sharded DDP only works in distributed training.")
|
||||
@@ -357,6 +363,30 @@ class Trainer:
|
||||
elif ShardedDDPOption.ZERO_DP_3 in args.sharded_ddp:
|
||||
self.sharded_ddp = ShardedDDPOption.ZERO_DP_3
|
||||
|
||||
self.fsdp = None
|
||||
if len(args.fsdp) > 0:
|
||||
if args.deepspeed:
|
||||
raise ValueError(
|
||||
"Using --fsdp xxx together with --deepspeed is not possible, deactivate one of those flags."
|
||||
)
|
||||
if args.local_rank == -1:
|
||||
raise ValueError("Using fsdp only works in distributed training.")
|
||||
|
||||
# dep_version_check("torch>=1.12.0.dev20220418+cu113")
|
||||
# Would have to update setup.py with torch>=1.12.0.dev20220418+cu113
|
||||
# which isn't ideal given that it's a dev version
|
||||
# and it will force people not using FSDP to also use torch>=1.12.0.dev20220418+cu113
|
||||
# below is the current alternative.
|
||||
if version.parse(torch.__version__) < version.parse("1.12.0.dev20220418+cu113"):
|
||||
raise ValueError("FSDP requires PyTorch >= 1.12.0.dev20220418+cu113")
|
||||
|
||||
from torch.distributed.fsdp.fully_sharded_data_parallel import ShardingStrategy
|
||||
|
||||
if FSDPOption.FULL_SHARD in args.fsdp:
|
||||
self.fsdp = ShardingStrategy.FULL_SHARD
|
||||
elif FSDPOption.SHARD_GRAD_OP in args.fsdp:
|
||||
self.fsdp = ShardingStrategy.SHARD_GRAD_OP
|
||||
|
||||
# one place to sort out whether to place the model on device or not
|
||||
# postpone switching model to cuda when:
|
||||
# 1. MP - since we are trying to fit a much bigger than 1 gpu model
|
||||
@@ -364,12 +394,14 @@ class Trainer:
|
||||
# and we only use deepspeed for training at the moment
|
||||
# 3. full bf16 or fp16 eval - since the model needs to be cast to the right dtype first
|
||||
# 4. Sharded DDP - same as MP
|
||||
# 5. FSDP - same as MP
|
||||
self.place_model_on_device = args.place_model_on_device
|
||||
if (
|
||||
self.is_model_parallel
|
||||
or args.deepspeed
|
||||
or ((args.fp16_full_eval or args.bf16_full_eval) and not args.do_train)
|
||||
or (self.sharded_ddp in [ShardedDDPOption.ZERO_DP_2, ShardedDDPOption.ZERO_DP_3])
|
||||
or (self.fsdp is not None)
|
||||
):
|
||||
self.place_model_on_device = False
|
||||
|
||||
@@ -398,11 +430,11 @@ class Trainer:
|
||||
"Passing a `model_init` is incompatible with providing the `optimizers` argument. "
|
||||
"You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
|
||||
)
|
||||
if (self.sharded_ddp is not None or args.deepspeed) and (
|
||||
if ((self.sharded_ddp is not None) or args.deepspeed or (self.fsdp is not None)) and (
|
||||
self.optimizer is not None or self.lr_scheduler is not None
|
||||
):
|
||||
raise RuntimeError(
|
||||
"Passing `optimizers` is not allowed if Fairscale or Deepspeed is enabled."
|
||||
"Passing `optimizers` is not allowed if Fairscale, Deepspeed or PyTorch FSDP is enabled."
|
||||
"You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
|
||||
)
|
||||
default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
|
||||
@@ -450,6 +482,11 @@ class Trainer:
|
||||
self.use_amp = False
|
||||
|
||||
if args.fp16 or args.bf16:
|
||||
if self.fsdp is not None:
    # Mixed precision (--fp16 / --bf16) cannot be combined with FSDP yet, so fail
    # fast instead of starting a run that is known to be unsupported.
    # The trailing space keeps the two concatenated sentences from running together.
    raise ValueError(
        "Mixed precision is currently not supported for FSDP. "
        "Please do not set arguments related to `mixed_precision`"
    )
|
||||
if args.half_precision_backend == "auto":
|
||||
if _is_native_amp_available:
|
||||
args.half_precision_backend = "amp"
|
||||
@@ -1102,6 +1139,33 @@ class Trainer:
|
||||
cpu_offload=cpu_offload,
|
||||
).to(self.args.device)
|
||||
|
||||
# Distributed training using PyTorch FSDP
|
||||
if self.fsdp is not None:
|
||||
# PyTorch FSDP!
|
||||
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
|
||||
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
|
||||
from torch.distributed.fsdp.wrap import default_auto_wrap_policy
|
||||
|
||||
if FSDPOption.OFFLOAD in self.args.fsdp:
|
||||
cpu_offload = CPUOffload(offload_params=True)
|
||||
else:
|
||||
cpu_offload = CPUOffload(offload_params=False)
|
||||
|
||||
auto_wrap_policy = None
|
||||
if FSDPOption.AUTO_WRAP in self.args.fsdp:
|
||||
if self.args.fsdp_min_num_params > 0:
|
||||
auto_wrap_policy = functools.partial(
|
||||
default_auto_wrap_policy, min_num_params=self.args.fsdp_min_num_params
|
||||
)
|
||||
|
||||
if type(model) != FSDP:
|
||||
# XXX: Breaking the self.model convention but I see no way around it for now.
|
||||
self.model = model = FSDP(
|
||||
model, sharding_strategy=self.fsdp, cpu_offload=cpu_offload, auto_wrap_policy=auto_wrap_policy
|
||||
)
|
||||
if FSDPOption.OFFLOAD not in self.args.fsdp:
|
||||
model.to(self.args.device)
|
||||
|
||||
elif is_sagemaker_dp_enabled():
|
||||
model = nn.parallel.DistributedDataParallel(
|
||||
model, device_ids=[int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))]
|
||||
@@ -1253,7 +1317,10 @@ class Trainer:
|
||||
debug_overflow = DebugUnderflowOverflow(self.model) # noqa
|
||||
|
||||
delay_optimizer_creation = (
|
||||
self.sharded_ddp is not None and self.sharded_ddp != ShardedDDPOption.SIMPLE or is_sagemaker_mp_enabled()
|
||||
self.sharded_ddp is not None
|
||||
and self.sharded_ddp != ShardedDDPOption.SIMPLE
|
||||
or is_sagemaker_mp_enabled()
|
||||
or self.fsdp is not None
|
||||
)
|
||||
if args.deepspeed:
|
||||
deepspeed_engine, optimizer, lr_scheduler = deepspeed_init(
|
||||
@@ -2138,7 +2205,9 @@ class Trainer:
|
||||
if self.args.should_save:
|
||||
self._save(output_dir, state_dict=state_dict)
|
||||
elif (
|
||||
ShardedDDPOption.ZERO_DP_2 in self.args.sharded_ddp or ShardedDDPOption.ZERO_DP_3 in self.args.sharded_ddp
|
||||
ShardedDDPOption.ZERO_DP_2 in self.args.sharded_ddp
|
||||
or ShardedDDPOption.ZERO_DP_3 in self.args.sharded_ddp
|
||||
or self.fsdp is not None
|
||||
):
|
||||
state_dict = self.model.state_dict()
|
||||
|
||||
|
||||
@@ -582,3 +582,10 @@ class ShardedDDPOption(ExplicitEnum):
|
||||
ZERO_DP_3 = "zero_dp_3"
|
||||
OFFLOAD = "offload"
|
||||
AUTO_WRAP = "auto_wrap"
|
||||
|
||||
|
||||
class FSDPOption(ExplicitEnum):
    """
    Options accepted by the `--fsdp` training argument (PyTorch Fully Sharded Data Parallel).

    `FULL_SHARD` and `SHARD_GRAD_OP` are the base sharding strategies; `OFFLOAD` and `AUTO_WRAP` are modifiers
    that are added to one of the base strategies, e.g. `"full_shard offload auto_wrap"`.
    """

    # Shard optimizer states, gradients and model parameters across data parallel workers.
    FULL_SHARD = "full_shard"
    # Shard optimizer states and gradients only.
    SHARD_GRAD_OP = "shard_grad_op"
    # Offload parameters and gradients to the CPU.
    OFFLOAD = "offload"
    # Automatically and recursively wrap layers with FSDP using `default_auto_wrap_policy`.
    AUTO_WRAP = "auto_wrap"
|
||||
|
||||
@@ -23,7 +23,14 @@ from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from .debug_utils import DebugOption
|
||||
from .trainer_utils import EvaluationStrategy, HubStrategy, IntervalStrategy, SchedulerType, ShardedDDPOption
|
||||
from .trainer_utils import (
|
||||
EvaluationStrategy,
|
||||
FSDPOption,
|
||||
HubStrategy,
|
||||
IntervalStrategy,
|
||||
SchedulerType,
|
||||
ShardedDDPOption,
|
||||
)
|
||||
from .utils import (
|
||||
ExplicitEnum,
|
||||
cached_property,
|
||||
@@ -331,6 +338,18 @@ class TrainingArguments:
|
||||
|
||||
If a string is passed, it will be split on space. If a bool is passed, it will be converted to an empty
|
||||
list for `False` and `["simple"]` for `True`.
|
||||
fsdp (`bool`, `str` or list of [`~trainer_utils.FSDPOption`], *optional*, defaults to `False`):
|
||||
Use PyTorch Fully Sharded Data Parallel training (in distributed training only).
|
||||
|
||||
A list of options among the following:
|
||||
|
||||
- `"full_shard"`: Shard parameters, gradients and optimizer states.
|
||||
- `"shard_grad_op"`: Shard optimizer states and gradients.
|
||||
- `"offload"`: Offload parameters and gradients to CPUs (only compatible with `"full_shard"` and
|
||||
`"shard_grad_op"`).
|
||||
- `"auto_wrap"`: Automatically recursively wrap layers with FSDP using `default_auto_wrap_policy`.
|
||||
fsdp_min_num_params (`int`, *optional*, defaults to `0`):
|
||||
FSDP's minimum number of parameters for Default Auto Wrapping. (useful only when `fsdp` field is passed).
|
||||
deepspeed (`str` or `dict`, *optional*):
|
||||
Use [Deepspeed](https://github.com/microsoft/deepspeed). This is an experimental feature and its API may
|
||||
evolve in the future. The value is either the location of DeepSpeed json config file (e.g.,
|
||||
@@ -674,10 +693,25 @@ class TrainingArguments:
|
||||
metadata={
|
||||
"help": "Whether or not to use sharded DDP training (in distributed training only). The base option "
|
||||
"should be `simple`, `zero_dp_2` or `zero_dp_3` and you can add CPU-offload to `zero_dp_2` or `zero_dp_3` "
|
||||
"like this: zero_dp_2 offload` or `zero_dp_3 offload`. You can add auto-wrap to `zero_dp_2` or "
|
||||
"like this: zero_dp_2 offload` or `zero_dp_3 offload`. You can add auto-wrap to `zero_dp_2` or `zero_dp_3` "
|
||||
"with the same syntax: zero_dp_2 auto_wrap` or `zero_dp_3 auto_wrap`.",
|
||||
},
|
||||
)
|
||||
# Space-separated FSDP options, e.g. "full_shard offload auto_wrap".
# The default empty string parses to an empty option list, which leaves FSDP disabled.
fsdp: str = field(
    default="",
    metadata={
        # Every option example is wrapped in a matching pair of backticks so the help text renders correctly.
        "help": "Whether or not to use PyTorch Fully Sharded Data Parallel (FSDP) training (in distributed training only). The base option "
        "should be `full_shard` or `shard_grad_op` and you can add CPU-offload to `full_shard` or `shard_grad_op` "
        "like this: `full_shard offload` or `shard_grad_op offload`. You can add auto-wrap to `full_shard` or `shard_grad_op` "
        "with the same syntax: `full_shard auto_wrap` or `shard_grad_op auto_wrap`.",
    },
)
|
||||
fsdp_min_num_params: int = field(
|
||||
default=0,
|
||||
metadata={
|
||||
"help": "FSDP's minimum number of parameters for Default Auto Wrapping. (useful only when `fsdp` field is passed)."
|
||||
},
|
||||
)
|
||||
deepspeed: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
@@ -931,6 +965,21 @@ class TrainingArguments:
|
||||
elif ShardedDDPOption.ZERO_DP_2 in self.sharded_ddp and ShardedDDPOption.ZERO_DP_3 in self.sharded_ddp:
|
||||
raise ValueError("`--sharded_ddp zero_dp_2` is not compatible with `--sharded_ddp zero_dp_3`.")
|
||||
|
||||
# Normalize `fsdp` into a list of `FSDPOption` members.
# A bool maps to "full_shard" (True) or "" (False); a string is split on whitespace.
if isinstance(self.fsdp, bool):
    self.fsdp = "full_shard" if self.fsdp else ""
if isinstance(self.fsdp, str):
    self.fsdp = [FSDPOption(s) for s in self.fsdp.split()]

# `offload` only modifies a base sharding strategy, so it is rejected on its own.
if self.fsdp == [FSDPOption.OFFLOAD]:
    raise ValueError(
        "`--fsdp offload` can't work on its own. It needs to be added to `--fsdp full_shard` or "
        '`--fsdp shard_grad_op`. For example, `--fsdp "full_shard offload"`.'
    )
# The two base strategies are mutually exclusive. Both membership tests must look at
# `self.fsdp`: checking `self.sharded_ddp` (which holds `ShardedDDPOption` members) could
# never match an `FSDPOption`, so the conflict was silently accepted.
elif FSDPOption.FULL_SHARD in self.fsdp and FSDPOption.SHARD_GRAD_OP in self.fsdp:
    raise ValueError("`--fsdp full_shard` is not compatible with `--fsdp shard_grad_op`.")
|
||||
|
||||
if len(self.fsdp) == 0 and self.fsdp_min_num_params > 0:
|
||||
warnings.warn("`--fsdp_min_num_params` is useful only when `--fsdp` is specified.")
|
||||
|
||||
if self.tpu_metrics_debug:
|
||||
warnings.warn(
|
||||
"using `--tpu_metrics_debug` is deprecated and will be removed in version 5 of 🤗 Transformers. Use `--debug tpu_metrics_debug` instead",
|
||||
|
||||
Reference in New Issue
Block a user