fsdp fixes and enhancements (#24980)

* fix fsdp prepare to remove the warnings and fix excess memory usage

* Update training_args.py

* parity for FSDP+XLA

* Update trainer.py
This commit is contained in:
Sourab Mangrulkar
2023-07-21 17:52:48 +05:30
committed by GitHub
parent ec3dfe5e24
commit f4eb459ef2
3 changed files with 14 additions and 5 deletions

View File

@@ -1567,6 +1567,7 @@ class TrainingArguments:
elif fsdp_option == FSDPOption.OFFLOAD:
os.environ["FSDP_OFFLOAD_PARAMS"] = "true"
elif fsdp_option == FSDPOption.AUTO_WRAP:
os.environ["FSDP_AUTO_WRAP_POLICY"] = FSDP_AUTO_WRAP_POLICY[0]
if self.fsdp_config["fsdp_min_num_params"] > 0:
os.environ["FSDP_MIN_NUM_PARAMS"] = str(self.fsdp_config["fsdp_min_num_params"])
os.environ["FSDP_AUTO_WRAP_POLICY"] = FSDP_AUTO_WRAP_POLICY[1]
@@ -1574,7 +1575,6 @@ class TrainingArguments:
os.environ["FSDP_TRANSFORMER_CLS_TO_WRAP"] = ",".join(
self.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]
)
os.environ["FSDP_AUTO_WRAP_POLICY"] = FSDP_AUTO_WRAP_POLICY[0]
prefetch_policy = self.fsdp_config.get("fsdp_backward_prefetch", "NO_PREFETCH")
os.environ["FSDP_BACKWARD_PREFETCH"] = prefetch_policy.upper()