PyTorch FSDP integration in Trainer (#17136)
* PyTorch FSDP integration in Trainer * reformatting make style and make quality are now compliant. * Updating dependency check * Trigger CI Co-authored-by: Sylvain Gugger <Sylvain.gugger@gmail.com>
This commit is contained in:
committed by
GitHub
parent
dc3645dc9c
commit
05fc1766ff
@@ -540,6 +540,42 @@ Known caveats:
|
|||||||
`FullyShardedDataParallelism` of fairscale. It should be used with the option `auto_wrap` if you are not
|
`FullyShardedDataParallelism` of fairscale. It should be used with the option `auto_wrap` if you are not
|
||||||
doing this yourself: `--sharded_ddp "zero_dp_3 auto_wrap"`.
|
doing this yourself: `--sharded_ddp "zero_dp_3 auto_wrap"`.
|
||||||
|
|
||||||
|
### PyTorch Fully Sharded Data Parallel
|
||||||
|
|
||||||
|
To accelerate training huge models on larger batch sizes, we can use a fully sharded data parallel model.
|
||||||
|
This type of data parallel paradigm enables fitting more data and larger models by sharding the optimizer states, gradients and parameters.
|
||||||
|
To read more about it and the benefits, check out the [Fully Sharded Data Parallel blog](https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/).
|
||||||
|
We have integrated PyTorch's latest Fully Sharded Data Parallel (FSDP) training feature.
|
||||||
|
All you need to do is enable it through the config.
|
||||||
|
|
||||||
|
**Required PyTorch version for FSDP support**: PyTorch Nightly (or 1.12.0 if you read this after it has been released)
|
||||||
|
as the model saving with FSDP activated is only available with recent fixes.
|
||||||
|
|
||||||
|
**Usage**:
|
||||||
|
|
||||||
|
- Make sure you have added the distributed launcher
|
||||||
|
`-m torch.distributed.launch --nproc_per_node=NUMBER_OF_GPUS_YOU_HAVE` if you haven't been using it already.
|
||||||
|
|
||||||
|
- **Sharding Strategy**:
|
||||||
|
- FULL_SHARD : Shards optimizer states + gradients + model parameters across data parallel workers/GPUs.
|
||||||
|
For this, add `--fsdp full_shard` to the command line arguments.
|
||||||
|
- SHARD_GRAD_OP : Shards optimizer states + gradients across data parallel workers/GPUs.
|
||||||
|
For this, add `--fsdp shard_grad_op` to the command line arguments.
|
||||||
|
- To offload the parameters and gradients to the CPU,
|
||||||
|
add `--fsdp "full_shard offload"` or `--fsdp "shard_grad_op offload"` to the command line arguments.
|
||||||
|
- To automatically recursively wrap layers with FSDP using `default_auto_wrap_policy`,
|
||||||
|
add `--fsdp "full_shard auto_wrap"` or `--fsdp "shard_grad_op auto_wrap"` to the command line arguments.
|
||||||
|
- To enable both CPU offloading and auto wrapping,
|
||||||
|
add `--fsdp "full_shard offload auto_wrap"` or `--fsdp "shard_grad_op offload auto_wrap"` to the command line arguments.
|
||||||
|
- If auto wrapping is enabled, please add `--fsdp_min_num_params <number>` to command line arguments.
|
||||||
|
It specifies FSDP's minimum number of parameters for Default Auto Wrapping.
|
||||||
|
|
||||||
|
**A few caveats to be aware of**
|
||||||
|
- Mixed precision is currently not supported with FSDP as we wait for PyTorch to fix support for it.
|
||||||
|
More details in this [issue](https://github.com/pytorch/pytorch/issues/75676).
|
||||||
|
- FSDP currently doesn't support multiple parameter groups.
|
||||||
|
More details mentioned in this [issue](https://github.com/pytorch/pytorch/issues/76501)
|
||||||
|
(`The original model parameters' .grads are not set, meaning that they cannot be optimized separately (which is why we cannot support multiple parameter groups)`).
|
||||||
|
|
||||||
Sections that were moved:
|
Sections that were moved:
|
||||||
|
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ The Trainer class, to easily train a 🤗 Transformers from scratch or finetune
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import contextlib
|
import contextlib
|
||||||
|
import functools
|
||||||
import inspect
|
import inspect
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
@@ -103,6 +104,7 @@ from .trainer_utils import (
|
|||||||
BestRun,
|
BestRun,
|
||||||
EvalLoopOutput,
|
EvalLoopOutput,
|
||||||
EvalPrediction,
|
EvalPrediction,
|
||||||
|
FSDPOption,
|
||||||
HPSearchBackend,
|
HPSearchBackend,
|
||||||
HubStrategy,
|
HubStrategy,
|
||||||
IntervalStrategy,
|
IntervalStrategy,
|
||||||
@@ -340,6 +342,10 @@ class Trainer:
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Using --sharded_ddp xxx together with --deepspeed is not possible, deactivate one of those flags."
|
"Using --sharded_ddp xxx together with --deepspeed is not possible, deactivate one of those flags."
|
||||||
)
|
)
|
||||||
|
if len(args.fsdp) > 0:
|
||||||
|
raise ValueError(
|
||||||
|
"Using --sharded_ddp xxx together with --fsdp is not possible, deactivate one of those flags."
|
||||||
|
)
|
||||||
|
|
||||||
if args.local_rank == -1:
|
if args.local_rank == -1:
|
||||||
raise ValueError("Using sharded DDP only works in distributed training.")
|
raise ValueError("Using sharded DDP only works in distributed training.")
|
||||||
@@ -357,6 +363,30 @@ class Trainer:
|
|||||||
elif ShardedDDPOption.ZERO_DP_3 in args.sharded_ddp:
|
elif ShardedDDPOption.ZERO_DP_3 in args.sharded_ddp:
|
||||||
self.sharded_ddp = ShardedDDPOption.ZERO_DP_3
|
self.sharded_ddp = ShardedDDPOption.ZERO_DP_3
|
||||||
|
|
||||||
|
self.fsdp = None
|
||||||
|
if len(args.fsdp) > 0:
|
||||||
|
if args.deepspeed:
|
||||||
|
raise ValueError(
|
||||||
|
"Using --fsdp xxx together with --deepspeed is not possible, deactivate one of those flags."
|
||||||
|
)
|
||||||
|
if args.local_rank == -1:
|
||||||
|
raise ValueError("Using fsdp only works in distributed training.")
|
||||||
|
|
||||||
|
# dep_version_check("torch>=1.12.0.dev20220418+cu113")
|
||||||
|
# Would have to update setup.py with torch>=1.12.0.dev20220418+cu113
|
||||||
|
# which isn't ideal given that it's a dev version
|
||||||
|
# and it will force people not using FSDP to also use torch>=1.12.0.dev20220418+cu113
|
||||||
|
# below is the current alternative.
|
||||||
|
if version.parse(torch.__version__) < version.parse("1.12.0.dev20220418+cu113"):
|
||||||
|
raise ValueError("FSDP requires PyTorch >= 1.12.0.dev20220418+cu113")
|
||||||
|
|
||||||
|
from torch.distributed.fsdp.fully_sharded_data_parallel import ShardingStrategy
|
||||||
|
|
||||||
|
if FSDPOption.FULL_SHARD in args.fsdp:
|
||||||
|
self.fsdp = ShardingStrategy.FULL_SHARD
|
||||||
|
elif FSDPOption.SHARD_GRAD_OP in args.fsdp:
|
||||||
|
self.fsdp = ShardingStrategy.SHARD_GRAD_OP
|
||||||
|
|
||||||
# one place to sort out whether to place the model on device or not
|
# one place to sort out whether to place the model on device or not
|
||||||
# postpone switching model to cuda when:
|
# postpone switching model to cuda when:
|
||||||
# 1. MP - since we are trying to fit a much bigger than 1 gpu model
|
# 1. MP - since we are trying to fit a much bigger than 1 gpu model
|
||||||
@@ -364,12 +394,14 @@ class Trainer:
|
|||||||
# and we only use deepspeed for training at the moment
|
# and we only use deepspeed for training at the moment
|
||||||
# 3. full bf16 or fp16 eval - since the model needs to be cast to the right dtype first
|
# 3. full bf16 or fp16 eval - since the model needs to be cast to the right dtype first
|
||||||
# 4. Sharded DDP - same as MP
|
# 4. Sharded DDP - same as MP
|
||||||
|
# 5. FSDP - same as MP
|
||||||
self.place_model_on_device = args.place_model_on_device
|
self.place_model_on_device = args.place_model_on_device
|
||||||
if (
|
if (
|
||||||
self.is_model_parallel
|
self.is_model_parallel
|
||||||
or args.deepspeed
|
or args.deepspeed
|
||||||
or ((args.fp16_full_eval or args.bf16_full_eval) and not args.do_train)
|
or ((args.fp16_full_eval or args.bf16_full_eval) and not args.do_train)
|
||||||
or (self.sharded_ddp in [ShardedDDPOption.ZERO_DP_2, ShardedDDPOption.ZERO_DP_3])
|
or (self.sharded_ddp in [ShardedDDPOption.ZERO_DP_2, ShardedDDPOption.ZERO_DP_3])
|
||||||
|
or (self.fsdp is not None)
|
||||||
):
|
):
|
||||||
self.place_model_on_device = False
|
self.place_model_on_device = False
|
||||||
|
|
||||||
@@ -398,11 +430,11 @@ class Trainer:
|
|||||||
"Passing a `model_init` is incompatible with providing the `optimizers` argument. "
|
"Passing a `model_init` is incompatible with providing the `optimizers` argument. "
|
||||||
"You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
|
"You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
|
||||||
)
|
)
|
||||||
if (self.sharded_ddp is not None or args.deepspeed) and (
|
if ((self.sharded_ddp is not None) or args.deepspeed or (self.fsdp is not None)) and (
|
||||||
self.optimizer is not None or self.lr_scheduler is not None
|
self.optimizer is not None or self.lr_scheduler is not None
|
||||||
):
|
):
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"Passing `optimizers` is not allowed if Fairscale or Deepspeed is enabled."
|
"Passing `optimizers` is not allowed if Fairscale, Deepspeed or PyTorch FSDP is enabled."
|
||||||
"You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
|
"You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
|
||||||
)
|
)
|
||||||
default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
|
default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
|
||||||
@@ -450,6 +482,11 @@ class Trainer:
|
|||||||
self.use_amp = False
|
self.use_amp = False
|
||||||
|
|
||||||
if args.fp16 or args.bf16:
|
if args.fp16 or args.bf16:
|
||||||
|
if self.fsdp is not None:
    # fp16/bf16 was requested while FSDP is active; this combination is rejected.
    # Adjacent string literals are concatenated verbatim, so the first literal
    # must end with a space to keep the two sentences apart in the message.
    raise ValueError(
        "Mixed precision is currently not supported for FSDP. "
        "Please do not set arguments related to `mixed_precision`"
    )
|
||||||
if args.half_precision_backend == "auto":
|
if args.half_precision_backend == "auto":
|
||||||
if _is_native_amp_available:
|
if _is_native_amp_available:
|
||||||
args.half_precision_backend = "amp"
|
args.half_precision_backend = "amp"
|
||||||
@@ -1102,6 +1139,33 @@ class Trainer:
|
|||||||
cpu_offload=cpu_offload,
|
cpu_offload=cpu_offload,
|
||||||
).to(self.args.device)
|
).to(self.args.device)
|
||||||
|
|
||||||
|
# Distributed training using PyTorch FSDP
|
||||||
|
if self.fsdp is not None:
|
||||||
|
# PyTorch FSDP!
|
||||||
|
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
|
||||||
|
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
|
||||||
|
from torch.distributed.fsdp.wrap import default_auto_wrap_policy
|
||||||
|
|
||||||
|
if FSDPOption.OFFLOAD in self.args.fsdp:
|
||||||
|
cpu_offload = CPUOffload(offload_params=True)
|
||||||
|
else:
|
||||||
|
cpu_offload = CPUOffload(offload_params=False)
|
||||||
|
|
||||||
|
auto_wrap_policy = None
|
||||||
|
if FSDPOption.AUTO_WRAP in self.args.fsdp:
|
||||||
|
if self.args.fsdp_min_num_params > 0:
|
||||||
|
auto_wrap_policy = functools.partial(
|
||||||
|
default_auto_wrap_policy, min_num_params=self.args.fsdp_min_num_params
|
||||||
|
)
|
||||||
|
|
||||||
|
if type(model) != FSDP:
|
||||||
|
# XXX: Breaking the self.model convention but I see no way around it for now.
|
||||||
|
self.model = model = FSDP(
|
||||||
|
model, sharding_strategy=self.fsdp, cpu_offload=cpu_offload, auto_wrap_policy=auto_wrap_policy
|
||||||
|
)
|
||||||
|
if FSDPOption.OFFLOAD not in self.args.fsdp:
|
||||||
|
model.to(self.args.device)
|
||||||
|
|
||||||
elif is_sagemaker_dp_enabled():
|
elif is_sagemaker_dp_enabled():
|
||||||
model = nn.parallel.DistributedDataParallel(
|
model = nn.parallel.DistributedDataParallel(
|
||||||
model, device_ids=[int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))]
|
model, device_ids=[int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))]
|
||||||
@@ -1253,7 +1317,10 @@ class Trainer:
|
|||||||
debug_overflow = DebugUnderflowOverflow(self.model) # noqa
|
debug_overflow = DebugUnderflowOverflow(self.model) # noqa
|
||||||
|
|
||||||
delay_optimizer_creation = (
|
delay_optimizer_creation = (
|
||||||
self.sharded_ddp is not None and self.sharded_ddp != ShardedDDPOption.SIMPLE or is_sagemaker_mp_enabled()
|
self.sharded_ddp is not None
|
||||||
|
and self.sharded_ddp != ShardedDDPOption.SIMPLE
|
||||||
|
or is_sagemaker_mp_enabled()
|
||||||
|
or self.fsdp is not None
|
||||||
)
|
)
|
||||||
if args.deepspeed:
|
if args.deepspeed:
|
||||||
deepspeed_engine, optimizer, lr_scheduler = deepspeed_init(
|
deepspeed_engine, optimizer, lr_scheduler = deepspeed_init(
|
||||||
@@ -2138,7 +2205,9 @@ class Trainer:
|
|||||||
if self.args.should_save:
|
if self.args.should_save:
|
||||||
self._save(output_dir, state_dict=state_dict)
|
self._save(output_dir, state_dict=state_dict)
|
||||||
elif (
|
elif (
|
||||||
ShardedDDPOption.ZERO_DP_2 in self.args.sharded_ddp or ShardedDDPOption.ZERO_DP_3 in self.args.sharded_ddp
|
ShardedDDPOption.ZERO_DP_2 in self.args.sharded_ddp
|
||||||
|
or ShardedDDPOption.ZERO_DP_3 in self.args.sharded_ddp
|
||||||
|
or self.fsdp is not None
|
||||||
):
|
):
|
||||||
state_dict = self.model.state_dict()
|
state_dict = self.model.state_dict()
|
||||||
|
|
||||||
|
|||||||
@@ -582,3 +582,10 @@ class ShardedDDPOption(ExplicitEnum):
|
|||||||
ZERO_DP_3 = "zero_dp_3"
|
ZERO_DP_3 = "zero_dp_3"
|
||||||
OFFLOAD = "offload"
|
OFFLOAD = "offload"
|
||||||
AUTO_WRAP = "auto_wrap"
|
AUTO_WRAP = "auto_wrap"
|
||||||
|
|
||||||
|
|
||||||
|
class FSDPOption(ExplicitEnum):
|
||||||
|
FULL_SHARD = "full_shard"
|
||||||
|
SHARD_GRAD_OP = "shard_grad_op"
|
||||||
|
OFFLOAD = "offload"
|
||||||
|
AUTO_WRAP = "auto_wrap"
|
||||||
|
|||||||
@@ -23,7 +23,14 @@ from pathlib import Path
|
|||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from .debug_utils import DebugOption
|
from .debug_utils import DebugOption
|
||||||
from .trainer_utils import EvaluationStrategy, HubStrategy, IntervalStrategy, SchedulerType, ShardedDDPOption
|
from .trainer_utils import (
|
||||||
|
EvaluationStrategy,
|
||||||
|
FSDPOption,
|
||||||
|
HubStrategy,
|
||||||
|
IntervalStrategy,
|
||||||
|
SchedulerType,
|
||||||
|
ShardedDDPOption,
|
||||||
|
)
|
||||||
from .utils import (
|
from .utils import (
|
||||||
ExplicitEnum,
|
ExplicitEnum,
|
||||||
cached_property,
|
cached_property,
|
||||||
@@ -331,6 +338,18 @@ class TrainingArguments:
|
|||||||
|
|
||||||
If a string is passed, it will be split on space. If a bool is passed, it will be converted to an empty
|
If a string is passed, it will be split on space. If a bool is passed, it will be converted to an empty
|
||||||
list for `False` and `["simple"]` for `True`.
|
list for `False` and `["simple"]` for `True`.
|
||||||
|
fsdp (`bool`, `str` or list of [`~trainer_utils.FSDPOption`], *optional*, defaults to `False`):
|
||||||
|
Use PyTorch Fully Sharded Data Parallel (FSDP) training (in distributed training only).
|
||||||
|
|
||||||
|
A list of options along the following:
|
||||||
|
|
||||||
|
- `"full_shard"`: Shard parameters, gradients and optimizer states.
|
||||||
|
- `"shard_grad_op"`: Shard optimizer states and gradients.
|
||||||
|
- `"offload"`: Offload parameters and gradients to CPUs (only compatible with `"full_shard"` and
|
||||||
|
`"shard_grad_op"`).
|
||||||
|
- `"auto_wrap"`: Automatically recursively wrap layers with FSDP using `default_auto_wrap_policy`.
|
||||||
|
fsdp_min_num_params (`int`, *optional*, defaults to `0`):
|
||||||
|
FSDP's minimum number of parameters for Default Auto Wrapping. (useful only when `fsdp` field is passed).
|
||||||
deepspeed (`str` or `dict`, *optional*):
|
deepspeed (`str` or `dict`, *optional*):
|
||||||
Use [Deepspeed](https://github.com/microsoft/deepspeed). This is an experimental feature and its API may
|
Use [Deepspeed](https://github.com/microsoft/deepspeed). This is an experimental feature and its API may
|
||||||
evolve in the future. The value is either the location of DeepSpeed json config file (e.g.,
|
evolve in the future. The value is either the location of DeepSpeed json config file (e.g.,
|
||||||
@@ -674,10 +693,25 @@ class TrainingArguments:
|
|||||||
metadata={
|
metadata={
|
||||||
"help": "Whether or not to use sharded DDP training (in distributed training only). The base option "
|
"help": "Whether or not to use sharded DDP training (in distributed training only). The base option "
|
||||||
"should be `simple`, `zero_dp_2` or `zero_dp_3` and you can add CPU-offload to `zero_dp_2` or `zero_dp_3` "
|
"should be `simple`, `zero_dp_2` or `zero_dp_3` and you can add CPU-offload to `zero_dp_2` or `zero_dp_3` "
|
||||||
"like this: zero_dp_2 offload` or `zero_dp_3 offload`. You can add auto-wrap to `zero_dp_2` or "
|
"like this: zero_dp_2 offload` or `zero_dp_3 offload`. You can add auto-wrap to `zero_dp_2` or `zero_dp_3` "
|
||||||
"with the same syntax: zero_dp_2 auto_wrap` or `zero_dp_3 auto_wrap`.",
|
"with the same syntax: zero_dp_2 auto_wrap` or `zero_dp_3 auto_wrap`.",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
fsdp: str = field(
|
||||||
|
default="",
|
||||||
|
metadata={
|
||||||
|
"help": "Whether or not to use PyTorch Fully Sharded Data Parallel (FSDP) training (in distributed training only). The base option "
|
||||||
|
"should be `full_shard` or `shard_grad_op` and you can add CPU-offload to `full_shard` or `shard_grad_op` "
|
||||||
|
"like this: full_shard offload` or `shard_grad_op offload`. You can add auto-wrap to `full_shard` or `shard_grad_op` "
|
||||||
|
"with the same syntax: full_shard auto_wrap` or `shard_grad_op auto_wrap`.",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
fsdp_min_num_params: int = field(
|
||||||
|
default=0,
|
||||||
|
metadata={
|
||||||
|
"help": "FSDP's minimum number of parameters for Default Auto Wrapping. (useful only when `fsdp` field is passed)."
|
||||||
|
},
|
||||||
|
)
|
||||||
deepspeed: Optional[str] = field(
|
deepspeed: Optional[str] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
@@ -931,6 +965,21 @@ class TrainingArguments:
|
|||||||
elif ShardedDDPOption.ZERO_DP_2 in self.sharded_ddp and ShardedDDPOption.ZERO_DP_3 in self.sharded_ddp:
|
elif ShardedDDPOption.ZERO_DP_2 in self.sharded_ddp and ShardedDDPOption.ZERO_DP_3 in self.sharded_ddp:
|
||||||
raise ValueError("`--sharded_ddp zero_dp_2` is not compatible with `--sharded_ddp zero_dp_3`.")
|
raise ValueError("`--sharded_ddp zero_dp_2` is not compatible with `--sharded_ddp zero_dp_3`.")
|
||||||
|
|
||||||
|
if isinstance(self.fsdp, bool):
|
||||||
|
self.fsdp = "full_shard" if self.fsdp else ""
|
||||||
|
if isinstance(self.fsdp, str):
|
||||||
|
self.fsdp = [FSDPOption(s) for s in self.fsdp.split()]
|
||||||
|
# Validate the parsed `--fsdp` option list.
if self.fsdp == [FSDPOption.OFFLOAD]:
    # `offload` only modifies a sharding strategy; it cannot be the sole option.
    raise ValueError(
        "`--fsdp offload` can't work on its own. It needs to be added to `--fsdp full_shard` or "
        '`--fsdp shard_grad_op`. For example, `--fsdp "full_shard offload"`.'
    )
elif FSDPOption.FULL_SHARD in self.fsdp and FSDPOption.SHARD_GRAD_OP in self.fsdp:
    # Both strategies must be looked up in `self.fsdp`: checking `self.sharded_ddp`
    # (a list of ShardedDDPOption) would never match an FSDPOption member, so the
    # conflicting combination would slip through unnoticed.
    raise ValueError("`--fsdp full_shard` is not compatible with `--fsdp shard_grad_op`.")
|
||||||
|
|
||||||
|
if len(self.fsdp) == 0 and self.fsdp_min_num_params > 0:
|
||||||
|
warnings.warn("`--fsdp_min_num_params` is useful only when `--fsdp` is specified.")
|
||||||
|
|
||||||
if self.tpu_metrics_debug:
|
if self.tpu_metrics_debug:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"using `--tpu_metrics_debug` is deprecated and will be removed in version 5 of 🤗 Transformers. Use `--debug tpu_metrics_debug` instead",
|
"using `--tpu_metrics_debug` is deprecated and will be removed in version 5 of 🤗 Transformers. Use `--debug tpu_metrics_debug` instead",
|
||||||
|
|||||||
Reference in New Issue
Block a user