move fsdp handling to accelerate (#23158)
* mixed precision support via accelerate
* fix issues
* fix for the sharded ddp case
* fix flax and tf failing tests
* refactor the place to create `Accelerator` object
* move ddp prep to accelerate
* fix 😅
* resolving comments
* move fsdp handling to accelerate
* fixes
* fix saving
This commit is contained in:
committed by
GitHub
parent
015829e6c4
commit
0b774074a5
@@ -343,6 +343,12 @@ class Trainer:
|
|||||||
# create accelerator object
|
# create accelerator object
|
||||||
self.accelerator = Accelerator()
|
self.accelerator = Accelerator()
|
||||||
|
|
||||||
|
# post accelerator creation setup
|
||||||
|
if getattr(self.accelerator.state, "fsdp_plugin", None) is not None:
|
||||||
|
fsdp_plugin = self.accelerator.state.fsdp_plugin
|
||||||
|
fsdp_plugin.limit_all_gathers = self.args.fsdp_config.get("limit_all_gathers", False)
|
||||||
|
fsdp_plugin.use_orig_params = self.args.fsdp_config.get("use_orig_params", False)
|
||||||
|
|
||||||
# memory metrics - must set up as early as possible
|
# memory metrics - must set up as early as possible
|
||||||
self._memory_tracker = TrainerMemoryTracker(self.args.skip_memory_metrics)
|
self._memory_tracker = TrainerMemoryTracker(self.args.skip_memory_metrics)
|
||||||
self._memory_tracker.start()
|
self._memory_tracker.start()
|
||||||
@@ -464,7 +470,7 @@ class Trainer:
|
|||||||
self.fsdp = ShardingStrategy.NO_SHARD
|
self.fsdp = ShardingStrategy.NO_SHARD
|
||||||
|
|
||||||
self.backward_prefetch = BackwardPrefetch.BACKWARD_PRE
|
self.backward_prefetch = BackwardPrefetch.BACKWARD_PRE
|
||||||
if "backward_prefetch" in self.args.fsdp_config and "backward_pos" in self.args.fsdp_config.get(
|
if "backward_prefetch" in self.args.fsdp_config and "backward_post" in self.args.fsdp_config.get(
|
||||||
"backward_prefetch", []
|
"backward_prefetch", []
|
||||||
):
|
):
|
||||||
self.backward_prefetch = BackwardPrefetch.BACKWARD_POST
|
self.backward_prefetch = BackwardPrefetch.BACKWARD_POST
|
||||||
@@ -1479,114 +1485,58 @@ class Trainer:
|
|||||||
cpu_offload=cpu_offload,
|
cpu_offload=cpu_offload,
|
||||||
).to(self.args.device)
|
).to(self.args.device)
|
||||||
# Distributed training using PyTorch FSDP
|
# Distributed training using PyTorch FSDP
|
||||||
elif self.fsdp is not None:
|
elif self.fsdp is not None and self.args.fsdp_config["xla"]:
|
||||||
if not self.args.fsdp_config["xla"]:
|
try:
|
||||||
# PyTorch FSDP!
|
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP
|
||||||
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, MixedPrecision
|
from torch_xla.distributed.fsdp import checkpoint_module
|
||||||
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
|
from torch_xla.distributed.fsdp.wrap import (
|
||||||
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy
|
size_based_auto_wrap_policy,
|
||||||
|
transformer_auto_wrap_policy,
|
||||||
if FSDPOption.OFFLOAD in self.args.fsdp:
|
|
||||||
cpu_offload = CPUOffload(offload_params=True)
|
|
||||||
else:
|
|
||||||
cpu_offload = CPUOffload(offload_params=False)
|
|
||||||
|
|
||||||
auto_wrap_policy = None
|
|
||||||
|
|
||||||
if FSDPOption.AUTO_WRAP in self.args.fsdp:
|
|
||||||
if self.args.fsdp_config["fsdp_min_num_params"] > 0:
|
|
||||||
auto_wrap_policy = functools.partial(
|
|
||||||
size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["fsdp_min_num_params"]
|
|
||||||
)
|
|
||||||
elif self.args.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None:
|
|
||||||
transformer_cls_to_wrap = set()
|
|
||||||
for layer_class in self.args.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]:
|
|
||||||
transformer_cls = get_module_class_from_name(model, layer_class)
|
|
||||||
if transformer_cls is None:
|
|
||||||
raise Exception("Could not find the transformer layer class to wrap in the model.")
|
|
||||||
else:
|
|
||||||
transformer_cls_to_wrap.add(transformer_cls)
|
|
||||||
auto_wrap_policy = functools.partial(
|
|
||||||
transformer_auto_wrap_policy,
|
|
||||||
# Transformer layer class to wrap
|
|
||||||
transformer_layer_cls=transformer_cls_to_wrap,
|
|
||||||
)
|
|
||||||
mixed_precision_policy = None
|
|
||||||
dtype = None
|
|
||||||
if self.args.fp16:
|
|
||||||
dtype = torch.float16
|
|
||||||
elif self.args.bf16:
|
|
||||||
dtype = torch.bfloat16
|
|
||||||
if dtype is not None:
|
|
||||||
mixed_precision_policy = MixedPrecision(param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=dtype)
|
|
||||||
if type(model) != FSDP:
|
|
||||||
# XXX: Breaking the self.model convention but I see no way around it for now.
|
|
||||||
signature = inspect.signature(FSDP.__init__).parameters.keys()
|
|
||||||
kwargs = {}
|
|
||||||
for arg in ["limit_all_gathers", "forward_prefetch", "backward_prefetch"]:
|
|
||||||
if arg in signature:
|
|
||||||
kwargs[arg] = getattr(self, arg)
|
|
||||||
self.model = model = FSDP(
|
|
||||||
model,
|
|
||||||
sharding_strategy=self.fsdp,
|
|
||||||
cpu_offload=cpu_offload,
|
|
||||||
auto_wrap_policy=auto_wrap_policy,
|
|
||||||
mixed_precision=mixed_precision_policy,
|
|
||||||
device_id=self.args.device,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP
|
|
||||||
from torch_xla.distributed.fsdp import checkpoint_module
|
|
||||||
from torch_xla.distributed.fsdp.wrap import (
|
|
||||||
size_based_auto_wrap_policy,
|
|
||||||
transformer_auto_wrap_policy,
|
|
||||||
)
|
|
||||||
except ImportError:
|
|
||||||
raise ImportError("Missing XLA FSDP related module; please make sure to use torch-xla >= 2.0.")
|
|
||||||
auto_wrap_policy = None
|
|
||||||
auto_wrapper_callable = None
|
|
||||||
if self.args.fsdp_config["fsdp_min_num_params"] > 0:
|
|
||||||
auto_wrap_policy = functools.partial(
|
|
||||||
size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["fsdp_min_num_params"]
|
|
||||||
)
|
|
||||||
elif self.args.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None:
|
|
||||||
transformer_cls_to_wrap = set()
|
|
||||||
for layer_class in self.args.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]:
|
|
||||||
transformer_cls = get_module_class_from_name(model, layer_class)
|
|
||||||
if transformer_cls is None:
|
|
||||||
raise Exception("Could not find the transformer layer class to wrap in the model.")
|
|
||||||
else:
|
|
||||||
transformer_cls_to_wrap.add(transformer_cls)
|
|
||||||
auto_wrap_policy = functools.partial(
|
|
||||||
transformer_auto_wrap_policy,
|
|
||||||
# Transformer layer class to wrap
|
|
||||||
transformer_layer_cls=transformer_cls_to_wrap,
|
|
||||||
)
|
|
||||||
fsdp_kwargs = self.args.xla_fsdp_config
|
|
||||||
if self.args.fsdp_config["xla_fsdp_grad_ckpt"]:
|
|
||||||
# Apply gradient checkpointing to auto-wrapped sub-modules if specified
|
|
||||||
def auto_wrapper_callable(m, *args, **kwargs):
|
|
||||||
return FSDP(checkpoint_module(m), *args, **kwargs)
|
|
||||||
|
|
||||||
# Wrap the base model with an outer FSDP wrapper
|
|
||||||
self.model = model = FSDP(
|
|
||||||
model,
|
|
||||||
auto_wrap_policy=auto_wrap_policy,
|
|
||||||
auto_wrapper_callable=auto_wrapper_callable,
|
|
||||||
**fsdp_kwargs,
|
|
||||||
)
|
)
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError("Missing XLA FSDP related module; please make sure to use torch-xla >= 2.0.")
|
||||||
|
auto_wrap_policy = None
|
||||||
|
auto_wrapper_callable = None
|
||||||
|
if self.args.fsdp_config["fsdp_min_num_params"] > 0:
|
||||||
|
auto_wrap_policy = functools.partial(
|
||||||
|
size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["fsdp_min_num_params"]
|
||||||
|
)
|
||||||
|
elif self.args.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None:
|
||||||
|
transformer_cls_to_wrap = set()
|
||||||
|
for layer_class in self.args.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]:
|
||||||
|
transformer_cls = get_module_class_from_name(model, layer_class)
|
||||||
|
if transformer_cls is None:
|
||||||
|
raise Exception("Could not find the transformer layer class to wrap in the model.")
|
||||||
|
else:
|
||||||
|
transformer_cls_to_wrap.add(transformer_cls)
|
||||||
|
auto_wrap_policy = functools.partial(
|
||||||
|
transformer_auto_wrap_policy,
|
||||||
|
# Transformer layer class to wrap
|
||||||
|
transformer_layer_cls=transformer_cls_to_wrap,
|
||||||
|
)
|
||||||
|
fsdp_kwargs = self.args.xla_fsdp_config
|
||||||
|
if self.args.fsdp_config["xla_fsdp_grad_ckpt"]:
|
||||||
|
# Apply gradient checkpointing to auto-wrapped sub-modules if specified
|
||||||
|
def auto_wrapper_callable(m, *args, **kwargs):
|
||||||
|
return FSDP(checkpoint_module(m), *args, **kwargs)
|
||||||
|
|
||||||
# Patch `xm.optimizer_step` should not reduce gradients in this case,
|
# Wrap the base model with an outer FSDP wrapper
|
||||||
# as FSDP does not need gradient reduction over sharded parameters.
|
self.model = model = FSDP(
|
||||||
def patched_optimizer_step(optimizer, barrier=False, optimizer_args={}):
|
model,
|
||||||
loss = optimizer.step(**optimizer_args)
|
auto_wrap_policy=auto_wrap_policy,
|
||||||
if barrier:
|
auto_wrapper_callable=auto_wrapper_callable,
|
||||||
xm.mark_step()
|
**fsdp_kwargs,
|
||||||
return loss
|
)
|
||||||
|
|
||||||
xm.optimizer_step = patched_optimizer_step
|
# Patch `xm.optimizer_step` should not reduce gradients in this case,
|
||||||
|
# as FSDP does not need gradient reduction over sharded parameters.
|
||||||
|
def patched_optimizer_step(optimizer, barrier=False, optimizer_args={}):
|
||||||
|
loss = optimizer.step(**optimizer_args)
|
||||||
|
if barrier:
|
||||||
|
xm.mark_step()
|
||||||
|
return loss
|
||||||
|
|
||||||
|
xm.optimizer_step = patched_optimizer_step
|
||||||
elif is_sagemaker_dp_enabled():
|
elif is_sagemaker_dp_enabled():
|
||||||
model = nn.parallel.DistributedDataParallel(
|
model = nn.parallel.DistributedDataParallel(
|
||||||
model, device_ids=[int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))]
|
model, device_ids=[int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))]
|
||||||
@@ -1796,17 +1746,26 @@ class Trainer:
|
|||||||
if is_sagemaker_mp_enabled() and resume_from_checkpoint is not None:
|
if is_sagemaker_mp_enabled() and resume_from_checkpoint is not None:
|
||||||
self._load_from_checkpoint(resume_from_checkpoint, model)
|
self._load_from_checkpoint(resume_from_checkpoint, model)
|
||||||
|
|
||||||
# for the rest of this function `model` is the outside model, whether it was wrapped or not
|
# as the model is wrapped, don't use `accelerator.prepare`
|
||||||
if model is not self.model:
|
# this is for unhandled cases such as
|
||||||
self.model_wrapped = model
|
# Fairscale Sharded DDP, FSDP-XLA, SageMaker MP/DP, DataParallel, IPEX
|
||||||
|
use_accelerator_prepare = True if model is self.model else False
|
||||||
|
|
||||||
if delay_optimizer_creation:
|
if delay_optimizer_creation:
|
||||||
self.create_optimizer_and_scheduler(num_training_steps=max_steps)
|
self.create_optimizer_and_scheduler(num_training_steps=max_steps)
|
||||||
|
|
||||||
# prepare using `accelerator` prepare
|
# prepare using `accelerator` prepare
|
||||||
model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(
|
if use_accelerator_prepare:
|
||||||
self.model, self.optimizer, self.lr_scheduler
|
model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(
|
||||||
)
|
self.model, self.optimizer, self.lr_scheduler
|
||||||
|
)
|
||||||
|
|
||||||
|
if getattr(self.accelerator.state, "fsdp_plugin", None) is not None:
|
||||||
|
self.model = model
|
||||||
|
|
||||||
|
# for the rest of this function `model` is the outside model, whether it was wrapped or not
|
||||||
|
if model is not self.model:
|
||||||
|
self.model_wrapped = model
|
||||||
|
|
||||||
# Check if saved optimizer or scheduler states exist
|
# Check if saved optimizer or scheduler states exist
|
||||||
self._load_optimizer_and_scheduler(resume_from_checkpoint)
|
self._load_optimizer_and_scheduler(resume_from_checkpoint)
|
||||||
@@ -2894,11 +2853,15 @@ class Trainer:
|
|||||||
ShardedDDPOption.ZERO_DP_2 in self.args.sharded_ddp
|
ShardedDDPOption.ZERO_DP_2 in self.args.sharded_ddp
|
||||||
or ShardedDDPOption.ZERO_DP_3 in self.args.sharded_ddp
|
or ShardedDDPOption.ZERO_DP_3 in self.args.sharded_ddp
|
||||||
or self.fsdp is not None
|
or self.fsdp is not None
|
||||||
|
or getattr(self.accelerator.state, "fsdp_plugin", None) is not None
|
||||||
):
|
):
|
||||||
state_dict = self.model.state_dict()
|
if getattr(self.accelerator.state, "fsdp_plugin", None) is not None:
|
||||||
|
self.accelerator.state.fsdp_plugin.save_model(self.accelerator, self.model, output_dir)
|
||||||
|
else:
|
||||||
|
state_dict = self.model.state_dict()
|
||||||
|
|
||||||
if self.args.should_save:
|
if self.args.should_save:
|
||||||
self._save(output_dir, state_dict=state_dict)
|
self._save(output_dir, state_dict=state_dict)
|
||||||
elif self.deepspeed:
|
elif self.deepspeed:
|
||||||
# this takes care of everything as long as we aren't under zero3
|
# this takes care of everything as long as we aren't under zero3
|
||||||
if self.args.should_save:
|
if self.args.should_save:
|
||||||
|
|||||||
@@ -442,7 +442,7 @@ class TrainingArguments:
|
|||||||
- `"backward_pre"` : Prefetches the next set of parameters before the current set of parameter's
|
- `"backward_pre"` : Prefetches the next set of parameters before the current set of parameter's
|
||||||
gradient
|
gradient
|
||||||
computation.
|
computation.
|
||||||
- `"backward_pos"` : This prefetches the next set of parameters after the current set of
|
- `"backward_post"` : This prefetches the next set of parameters after the current set of
|
||||||
parameter’s
|
parameter’s
|
||||||
gradient computation.
|
gradient computation.
|
||||||
- fsdp_forward_prefetch (`bool`, *optional*, defaults to `False`)
|
- fsdp_forward_prefetch (`bool`, *optional*, defaults to `False`)
|
||||||
@@ -1504,6 +1504,32 @@ class TrainingArguments:
|
|||||||
if self.fsdp_config["xla_fsdp_grad_ckpt"]:
|
if self.fsdp_config["xla_fsdp_grad_ckpt"]:
|
||||||
warnings.warn("`--xla_fsdp_grad_ckpt` is useful only when `--xla` is set to true.")
|
warnings.warn("`--xla_fsdp_grad_ckpt` is useful only when `--xla` is set to true.")
|
||||||
|
|
||||||
|
# accelerate integration for FSDP
|
||||||
|
if len(self.fsdp) > 0 and not self.fsdp_config["xla"]:
|
||||||
|
os.environ["ACCELERATE_USE_FSDP"] = "true"
|
||||||
|
from accelerate.utils.constants import (
|
||||||
|
FSDP_AUTO_WRAP_POLICY,
|
||||||
|
FSDP_SHARDING_STRATEGY,
|
||||||
|
)
|
||||||
|
|
||||||
|
for fsdp_option in self.fsdp:
|
||||||
|
if fsdp_option.upper() in FSDP_SHARDING_STRATEGY:
|
||||||
|
# set environment variable for FSDP sharding strategy
|
||||||
|
os.environ["FSDP_SHARDING_STRATEGY"] = str(FSDP_SHARDING_STRATEGY.index(fsdp_option.upper()) + 1)
|
||||||
|
elif fsdp_option == FSDPOption.OFFLOAD:
|
||||||
|
os.environ["FSDP_OFFLOAD_PARAMS"] = "true"
|
||||||
|
elif fsdp_option == FSDPOption.AUTO_WRAP:
|
||||||
|
if self.fsdp_config["fsdp_min_num_params"] > 0:
|
||||||
|
os.environ["FSDP_MIN_NUM_PARAMS"] = str(self.fsdp_config["fsdp_min_num_params"])
|
||||||
|
os.environ["FSDP_AUTO_WRAP_POLICY"] = FSDP_AUTO_WRAP_POLICY[1]
|
||||||
|
elif self.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None:
|
||||||
|
os.environ["FSDP_TRANSFORMER_CLS_TO_WRAP"] = ",".join(
|
||||||
|
self.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]
|
||||||
|
)
|
||||||
|
os.environ["FSDP_AUTO_WRAP_POLICY"] = FSDP_AUTO_WRAP_POLICY[0]
|
||||||
|
prefetch_policy = self.fsdp_config.get("fsdp_backward_prefetch", "NO_PREFETCH")
|
||||||
|
os.environ["FSDP_BACKWARD_PREFETCH"] = prefetch_policy.upper()
|
||||||
|
|
||||||
if self.tpu_metrics_debug:
|
if self.tpu_metrics_debug:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"using `--tpu_metrics_debug` is deprecated and will be removed in version 5 of 🤗 Transformers. Use"
|
"using `--tpu_metrics_debug` is deprecated and will be removed in version 5 of 🤗 Transformers. Use"
|
||||||
|
|||||||
Reference in New Issue
Block a user