move fsdp handling to accelerate (#23158)

* mixed precision support via accelerate

* fix issues

* fix for the sharded ddp case

* fix flax and tf failing tests

* `refactor the place to create `Accelerator` object

* move ddp prep to accelerate

* fix 😅

* resolving comments

* move fsdp handling to accelerate

* fixex

* fix saving
This commit is contained in:
Sourab Mangrulkar
2023-05-31 14:10:46 +05:30
committed by GitHub
parent 015829e6c4
commit 0b774074a5
2 changed files with 105 additions and 116 deletions

View File

@@ -442,7 +442,7 @@ class TrainingArguments:
- `"backward_pre"` : Prefetches the next set of parameters before the current set of parameter's
gradient
computation.
- `"backward_pos"` : This prefetches the next set of parameters after the current set of
- `"backward_post"` : This prefetches the next set of parameters after the current set of
parameters
gradient computation.
- fsdp_forward_prefetch (`bool`, *optional*, defaults to `False`)
@@ -1504,6 +1504,32 @@ class TrainingArguments:
if self.fsdp_config["xla_fsdp_grad_ckpt"]:
warnings.warn("`--xla_fsdp_grad_ckpt` is useful only when `--xla` is set to true.")
# accelerate integration for FSDP
if len(self.fsdp) > 0 and not self.fsdp_config["xla"]:
os.environ["ACCELERATE_USE_FSDP"] = "true"
from accelerate.utils.constants import (
FSDP_AUTO_WRAP_POLICY,
FSDP_SHARDING_STRATEGY,
)
for fsdp_option in self.fsdp:
if fsdp_option.upper() in FSDP_SHARDING_STRATEGY:
# set environment variable for FSDP sharding strategy
os.environ["FSDP_SHARDING_STRATEGY"] = str(FSDP_SHARDING_STRATEGY.index(fsdp_option.upper()) + 1)
elif fsdp_option == FSDPOption.OFFLOAD:
os.environ["FSDP_OFFLOAD_PARAMS"] = "true"
elif fsdp_option == FSDPOption.AUTO_WRAP:
if self.fsdp_config["fsdp_min_num_params"] > 0:
os.environ["FSDP_MIN_NUM_PARAMS"] = str(self.fsdp_config["fsdp_min_num_params"])
os.environ["FSDP_AUTO_WRAP_POLICY"] = FSDP_AUTO_WRAP_POLICY[1]
elif self.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None:
os.environ["FSDP_TRANSFORMER_CLS_TO_WRAP"] = ",".join(
self.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]
)
os.environ["FSDP_AUTO_WRAP_POLICY"] = FSDP_AUTO_WRAP_POLICY[0]
prefetch_policy = self.fsdp_config.get("fsdp_backward_prefetch", "NO_PREFETCH")
os.environ["FSDP_BACKWARD_PREFETCH"] = prefetch_policy.upper()
if self.tpu_metrics_debug:
warnings.warn(
"using `--tpu_metrics_debug` is deprecated and will be removed in version 5 of 🤗 Transformers. Use"