move fsdp handling to accelerate (#23158)
* mixed precision support via accelerate
* fix issues
* fix for the sharded ddp case
* fix flax and tf failing tests
* refactor the place to create `Accelerator` object
* move ddp prep to accelerate
* fix 😅
* resolving comments
* move fsdp handling to accelerate
* fixes
* fix saving
This commit is contained in:
committed by
GitHub
parent
015829e6c4
commit
0b774074a5
@@ -442,7 +442,7 @@ class TrainingArguments:
|
||||
- `"backward_pre"` : Prefetches the next set of parameters before the current set of parameter's
|
||||
gradient
|
||||
computation.
|
||||
- `"backward_pos"` : This prefetches the next set of parameters after the current set of
|
||||
- `"backward_post"` : This prefetches the next set of parameters after the current set of
|
||||
parameter’s
|
||||
gradient computation.
|
||||
- fsdp_forward_prefetch (`bool`, *optional*, defaults to `False`)
|
||||
@@ -1504,6 +1504,32 @@ class TrainingArguments:
|
||||
if self.fsdp_config["xla_fsdp_grad_ckpt"]:
|
||||
warnings.warn("`--xla_fsdp_grad_ckpt` is useful only when `--xla` is set to true.")
|
||||
|
||||
# accelerate integration for FSDP
|
||||
if len(self.fsdp) > 0 and not self.fsdp_config["xla"]:
|
||||
os.environ["ACCELERATE_USE_FSDP"] = "true"
|
||||
from accelerate.utils.constants import (
|
||||
FSDP_AUTO_WRAP_POLICY,
|
||||
FSDP_SHARDING_STRATEGY,
|
||||
)
|
||||
|
||||
for fsdp_option in self.fsdp:
|
||||
if fsdp_option.upper() in FSDP_SHARDING_STRATEGY:
|
||||
# set environment variable for FSDP sharding strategy
|
||||
os.environ["FSDP_SHARDING_STRATEGY"] = str(FSDP_SHARDING_STRATEGY.index(fsdp_option.upper()) + 1)
|
||||
elif fsdp_option == FSDPOption.OFFLOAD:
|
||||
os.environ["FSDP_OFFLOAD_PARAMS"] = "true"
|
||||
elif fsdp_option == FSDPOption.AUTO_WRAP:
|
||||
if self.fsdp_config["fsdp_min_num_params"] > 0:
|
||||
os.environ["FSDP_MIN_NUM_PARAMS"] = str(self.fsdp_config["fsdp_min_num_params"])
|
||||
os.environ["FSDP_AUTO_WRAP_POLICY"] = FSDP_AUTO_WRAP_POLICY[1]
|
||||
elif self.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None:
|
||||
os.environ["FSDP_TRANSFORMER_CLS_TO_WRAP"] = ",".join(
|
||||
self.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]
|
||||
)
|
||||
os.environ["FSDP_AUTO_WRAP_POLICY"] = FSDP_AUTO_WRAP_POLICY[0]
|
||||
prefetch_policy = self.fsdp_config.get("fsdp_backward_prefetch", "NO_PREFETCH")
|
||||
os.environ["FSDP_BACKWARD_PREFETCH"] = prefetch_policy.upper()
|
||||
|
||||
if self.tpu_metrics_debug:
|
||||
warnings.warn(
|
||||
"using `--tpu_metrics_debug` is deprecated and will be removed in version 5 of 🤗 Transformers. Use"
|
||||
|
||||
Reference in New Issue
Block a user