fix resuming from ckpt when using FSDP with FULL_STATE_DICT (#27891)

* fix resuming from ckpt when suing FSDP with FULL_STATE_DICT

* update tests

* fix tests
This commit is contained in:
Sourab Mangrulkar
2023-12-16 19:41:43 +05:30
committed by GitHub
parent ebfdb9ca62
commit 238d2e3c44
2 changed files with 23 additions and 4 deletions

View File

@@ -41,6 +41,7 @@ from transformers.utils import is_accelerate_available, is_torch_bf16_available_
if is_torch_available():
from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_1
from transformers.trainer import FSDP_MODEL_NAME
else:
is_torch_greater_or_equal_than_2_1 = False
@@ -211,6 +212,19 @@ class TrainerIntegrationFSDP(TestCasePlus, TrainerIntegrationCommon):
# resume from ckpt
checkpoint = os.path.join(output_dir, "checkpoint-115")
resume_args = args + f"--resume_from_checkpoint {checkpoint}".split()
is_fsdp_ckpt = os.path.isdir(checkpoint) and (
# this checks the FSDP state dict when `SHARDED_STATE_DICT` is used
any(
FSDP_MODEL_NAME in folder_name
for folder_name in os.listdir(checkpoint)
if os.path.isdir(os.path.join(checkpoint, folder_name))
)
# this checks the FSDP state dict when `FULL_STATE_DICT` is used
or os.path.isfile(os.path.join(checkpoint, f"{FSDP_MODEL_NAME}.bin"))
)
self.assertTrue(is_fsdp_ckpt)
logs_resume = self.run_cmd_and_get_logs(
use_accelerate, sharding_strategy, launcher, script, resume_args, output_dir
)