From f16caf44bb1606652ac6c7c4ad4bf44973d4e545 Mon Sep 17 00:00:00 2001 From: Howard Liberty Date: Mon, 22 Apr 2024 05:15:28 -0700 Subject: [PATCH] Add FSDP config for CPU RAM efficient loading through accelerate (#30002) * Add FSDP config for CPU RAM efficient loading * Style fix * Update src/transformers/training_args.py Co-authored-by: Zach Mueller * Update src/transformers/training_args.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Add sync_module_states and cpu_ram_efficient_loading validation logic * Update src/transformers/training_args.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Style --------- Co-authored-by: Zach Mueller Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --- src/transformers/training_args.py | 18 +++++++++++++++++- tests/fsdp/test_fsdp.py | 4 ++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 0b493e5d1d..5e81c22db9 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -513,6 +513,11 @@ class TrainingArguments: - sync_module_states (`bool`, *optional*, defaults to `True`) If `"True"`, each individually wrapped FSDP unit will broadcast module parameters from rank 0 to ensure they are the same across all ranks after initialization + - cpu_ram_efficient_loading (`bool`, *optional*, defaults to `False`) + If `"True"`, only the first process loads the pretrained model checkpoint while all other processes + have empty weights. When this setting as `"True"`, `sync_module_states` also must to be `"True"`, + otherwise all the processes except the main process would have random weights leading to unexpected + behaviour during training. - activation_checkpointing (`bool`, *optional*, defaults to `False`): If `"True"`, activation checkpointing is a technique to reduce memory usage by clearing activations of certain layers and recomputing them during a backward pass. Effectively, this trades extra @@ -1826,7 +1831,18 @@ class TrainingArguments: prefetch_policy = self.fsdp_config.get("backward_prefetch", "NO_PREFETCH") os.environ[f"{prefix}BACKWARD_PREFETCH"] = prefetch_policy.upper() os.environ[f"{prefix}FORWARD_PREFETCH"] = self.fsdp_config.get("forward_prefetch", "false") - os.environ[f"{prefix}SYNC_MODULE_STATES"] = self.fsdp_config.get("sync_module_states", "true") + + sync_module_states = self.fsdp_config.get("sync_module_states", "true") + cpu_ram_efficient_loading = self.fsdp_config.get("cpu_ram_efficient_loading", "false") + + if str(sync_module_states).lower() == "false" and str(cpu_ram_efficient_loading).lower() == "true": + # In this case, all the processes except the main process would have random weights leading + # to unexpected behaviour during training, thus throwing error here to prevent it. + raise ValueError('`sync_module_states` must be `"True"` if `cpu_ram_efficient_loading` is `"True"`') + + os.environ[f"{prefix}SYNC_MODULE_STATES"] = sync_module_states + os.environ[f"{prefix}CPU_RAM_EFFICIENT_LOADING"] = cpu_ram_efficient_loading + os.environ[f"{prefix}USE_ORIG_PARAMS"] = self.fsdp_config.get("use_orig_params", "true") if is_accelerate_available(): diff --git a/tests/fsdp/test_fsdp.py b/tests/fsdp/test_fsdp.py index 71293f1601..9ae55ecdec 100644 --- a/tests/fsdp/test_fsdp.py +++ b/tests/fsdp/test_fsdp.py @@ -144,6 +144,7 @@ class TrainerIntegrationFSDP(TestCasePlus, TrainerIntegrationCommon): "limit_all_gathers": "False", "use_orig_params": "True", "sync_module_states": "True", + "cpu_ram_efficient_loading": "True", "activation_checkpointing": "False", "min_num_params": 1, } @@ -208,6 +209,9 @@ class TrainerIntegrationFSDP(TestCasePlus, TrainerIntegrationCommon): self.assertEqual(os.environ[f"{prefix}FORWARD_PREFETCH"], fsdp_config["forward_prefetch"]) self.assertEqual(os.environ[f"{prefix}USE_ORIG_PARAMS"], fsdp_config["use_orig_params"]) self.assertEqual(os.environ[f"{prefix}SYNC_MODULE_STATES"], fsdp_config["sync_module_states"]) + self.assertEqual( + os.environ[f"{prefix}CPU_RAM_EFFICIENT_LOADING"], fsdp_config["cpu_ram_efficient_loading"] + ) self.assertEqual(os.environ.get("ACCELERATE_USE_FSDP", "false"), "true") @parameterized.expand(params, name_func=_parameterized_custom_name_func)