Add FSDP config for CPU RAM efficient loading through accelerate (#30002)
* Add FSDP config for CPU RAM efficient loading * Style fix * Update src/transformers/training_args.py Co-authored-by: Zach Mueller <muellerzr@gmail.com> * Update src/transformers/training_args.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Add sync_module_states and cpu_ram_efficient_loading validation logic * Update src/transformers/training_args.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Style --------- Co-authored-by: Zach Mueller <muellerzr@gmail.com> Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -513,6 +513,11 @@ class TrainingArguments:
|
|||||||
- sync_module_states (`bool`, *optional*, defaults to `True`)
|
- sync_module_states (`bool`, *optional*, defaults to `True`)
|
||||||
If `"True"`, each individually wrapped FSDP unit will broadcast module parameters from rank 0 to
|
If `"True"`, each individually wrapped FSDP unit will broadcast module parameters from rank 0 to
|
||||||
ensure they are the same across all ranks after initialization
|
ensure they are the same across all ranks after initialization
|
||||||
|
- cpu_ram_efficient_loading (`bool`, *optional*, defaults to `False`)
|
||||||
|
If `"True"`, only the first process loads the pretrained model checkpoint while all other processes
|
||||||
|
have empty weights. When this setting as `"True"`, `sync_module_states` also must to be `"True"`,
|
||||||
|
otherwise all the processes except the main process would have random weights leading to unexpected
|
||||||
|
behaviour during training.
|
||||||
- activation_checkpointing (`bool`, *optional*, defaults to `False`):
|
- activation_checkpointing (`bool`, *optional*, defaults to `False`):
|
||||||
If `"True"`, activation checkpointing is a technique to reduce memory usage by clearing activations of
|
If `"True"`, activation checkpointing is a technique to reduce memory usage by clearing activations of
|
||||||
certain layers and recomputing them during a backward pass. Effectively, this trades extra
|
certain layers and recomputing them during a backward pass. Effectively, this trades extra
|
||||||
@@ -1826,7 +1831,18 @@ class TrainingArguments:
|
|||||||
prefetch_policy = self.fsdp_config.get("backward_prefetch", "NO_PREFETCH")
|
prefetch_policy = self.fsdp_config.get("backward_prefetch", "NO_PREFETCH")
|
||||||
os.environ[f"{prefix}BACKWARD_PREFETCH"] = prefetch_policy.upper()
|
os.environ[f"{prefix}BACKWARD_PREFETCH"] = prefetch_policy.upper()
|
||||||
os.environ[f"{prefix}FORWARD_PREFETCH"] = self.fsdp_config.get("forward_prefetch", "false")
|
os.environ[f"{prefix}FORWARD_PREFETCH"] = self.fsdp_config.get("forward_prefetch", "false")
|
||||||
os.environ[f"{prefix}SYNC_MODULE_STATES"] = self.fsdp_config.get("sync_module_states", "true")
|
|
||||||
|
sync_module_states = self.fsdp_config.get("sync_module_states", "true")
|
||||||
|
cpu_ram_efficient_loading = self.fsdp_config.get("cpu_ram_efficient_loading", "false")
|
||||||
|
|
||||||
|
if str(sync_module_states).lower() == "false" and str(cpu_ram_efficient_loading).lower() == "true":
|
||||||
|
# In this case, all the processes except the main process would have random weights leading
|
||||||
|
# to unexpected behaviour during training, thus throwing error here to prevent it.
|
||||||
|
raise ValueError('`sync_module_states` must be `"True"` if `cpu_ram_efficient_loading` is `"True"`')
|
||||||
|
|
||||||
|
os.environ[f"{prefix}SYNC_MODULE_STATES"] = sync_module_states
|
||||||
|
os.environ[f"{prefix}CPU_RAM_EFFICIENT_LOADING"] = cpu_ram_efficient_loading
|
||||||
|
|
||||||
os.environ[f"{prefix}USE_ORIG_PARAMS"] = self.fsdp_config.get("use_orig_params", "true")
|
os.environ[f"{prefix}USE_ORIG_PARAMS"] = self.fsdp_config.get("use_orig_params", "true")
|
||||||
|
|
||||||
if is_accelerate_available():
|
if is_accelerate_available():
|
||||||
|
|||||||
@@ -144,6 +144,7 @@ class TrainerIntegrationFSDP(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
"limit_all_gathers": "False",
|
"limit_all_gathers": "False",
|
||||||
"use_orig_params": "True",
|
"use_orig_params": "True",
|
||||||
"sync_module_states": "True",
|
"sync_module_states": "True",
|
||||||
|
"cpu_ram_efficient_loading": "True",
|
||||||
"activation_checkpointing": "False",
|
"activation_checkpointing": "False",
|
||||||
"min_num_params": 1,
|
"min_num_params": 1,
|
||||||
}
|
}
|
||||||
@@ -208,6 +209,9 @@ class TrainerIntegrationFSDP(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
self.assertEqual(os.environ[f"{prefix}FORWARD_PREFETCH"], fsdp_config["forward_prefetch"])
|
self.assertEqual(os.environ[f"{prefix}FORWARD_PREFETCH"], fsdp_config["forward_prefetch"])
|
||||||
self.assertEqual(os.environ[f"{prefix}USE_ORIG_PARAMS"], fsdp_config["use_orig_params"])
|
self.assertEqual(os.environ[f"{prefix}USE_ORIG_PARAMS"], fsdp_config["use_orig_params"])
|
||||||
self.assertEqual(os.environ[f"{prefix}SYNC_MODULE_STATES"], fsdp_config["sync_module_states"])
|
self.assertEqual(os.environ[f"{prefix}SYNC_MODULE_STATES"], fsdp_config["sync_module_states"])
|
||||||
|
self.assertEqual(
|
||||||
|
os.environ[f"{prefix}CPU_RAM_EFFICIENT_LOADING"], fsdp_config["cpu_ram_efficient_loading"]
|
||||||
|
)
|
||||||
self.assertEqual(os.environ.get("ACCELERATE_USE_FSDP", "false"), "true")
|
self.assertEqual(os.environ.get("ACCELERATE_USE_FSDP", "false"), "true")
|
||||||
|
|
||||||
@parameterized.expand(params, name_func=_parameterized_custom_name_func)
|
@parameterized.expand(params, name_func=_parameterized_custom_name_func)
|
||||||
|
|||||||
Reference in New Issue
Block a user