fsdp fixes and enhancements (#24980)
* fix fsdp prepare to remove the warnings and fix excess memory usage
* Update training_args.py
* parity for FSDP+XLA
* Update trainer.py
This commit is contained in:
committed by
GitHub
parent
ec3dfe5e24
commit
f4eb459ef2
@@ -1567,6 +1567,7 @@ class TrainingArguments:
|
||||
elif fsdp_option == FSDPOption.OFFLOAD:
|
||||
os.environ["FSDP_OFFLOAD_PARAMS"] = "true"
|
||||
elif fsdp_option == FSDPOption.AUTO_WRAP:
|
||||
os.environ["FSDP_AUTO_WRAP_POLICY"] = FSDP_AUTO_WRAP_POLICY[0]
|
||||
if self.fsdp_config["fsdp_min_num_params"] > 0:
|
||||
os.environ["FSDP_MIN_NUM_PARAMS"] = str(self.fsdp_config["fsdp_min_num_params"])
|
||||
os.environ["FSDP_AUTO_WRAP_POLICY"] = FSDP_AUTO_WRAP_POLICY[1]
|
||||
@@ -1574,7 +1575,6 @@ class TrainingArguments:
|
||||
os.environ["FSDP_TRANSFORMER_CLS_TO_WRAP"] = ",".join(
|
||||
self.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]
|
||||
)
|
||||
os.environ["FSDP_AUTO_WRAP_POLICY"] = FSDP_AUTO_WRAP_POLICY[0]
|
||||
prefetch_policy = self.fsdp_config.get("fsdp_backward_prefetch", "NO_PREFETCH")
|
||||
os.environ["FSDP_BACKWARD_PREFETCH"] = prefetch_policy.upper()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user