fix resuming from ckpt when using FSDP with FULL_STATE_DICT (#27891)
* fix resuming from ckpt when using FSDP with FULL_STATE_DICT * update tests * fix tests
This commit is contained in:
committed by
GitHub
parent
ebfdb9ca62
commit
238d2e3c44
@@ -2033,10 +2033,15 @@ class Trainer:
|
|||||||
weights_index_file = os.path.join(resume_from_checkpoint, WEIGHTS_INDEX_NAME)
|
weights_index_file = os.path.join(resume_from_checkpoint, WEIGHTS_INDEX_NAME)
|
||||||
safe_weights_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_NAME)
|
safe_weights_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_NAME)
|
||||||
safe_weights_index_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_INDEX_NAME)
|
safe_weights_index_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_INDEX_NAME)
|
||||||
is_fsdp_ckpt = os.path.isdir(resume_from_checkpoint) and any(
|
is_fsdp_ckpt = os.path.isdir(resume_from_checkpoint) and (
|
||||||
FSDP_MODEL_NAME in folder_name
|
# this checks the FSDP state dict when `SHARDED_STATE_DICT` is used
|
||||||
for folder_name in os.listdir(resume_from_checkpoint)
|
any(
|
||||||
if os.path.isdir(os.path.join(resume_from_checkpoint, folder_name))
|
FSDP_MODEL_NAME in folder_name
|
||||||
|
for folder_name in os.listdir(resume_from_checkpoint)
|
||||||
|
if os.path.isdir(os.path.join(resume_from_checkpoint, folder_name))
|
||||||
|
)
|
||||||
|
# this checks the FSDP state dict when `FULL_STATE_DICT` is used
|
||||||
|
or os.path.isfile(os.path.join(resume_from_checkpoint, f"{FSDP_MODEL_NAME}.bin"))
|
||||||
)
|
)
|
||||||
|
|
||||||
if is_fsdp_ckpt and not self.is_fsdp_enabled:
|
if is_fsdp_ckpt and not self.is_fsdp_enabled:
|
||||||
|
|||||||
@@ -41,6 +41,7 @@ from transformers.utils import is_accelerate_available, is_torch_bf16_available_
|
|||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_1
|
from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_1
|
||||||
|
from transformers.trainer import FSDP_MODEL_NAME
|
||||||
else:
|
else:
|
||||||
is_torch_greater_or_equal_than_2_1 = False
|
is_torch_greater_or_equal_than_2_1 = False
|
||||||
|
|
||||||
@@ -211,6 +212,19 @@ class TrainerIntegrationFSDP(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
# resume from ckpt
|
# resume from ckpt
|
||||||
checkpoint = os.path.join(output_dir, "checkpoint-115")
|
checkpoint = os.path.join(output_dir, "checkpoint-115")
|
||||||
resume_args = args + f"--resume_from_checkpoint {checkpoint}".split()
|
resume_args = args + f"--resume_from_checkpoint {checkpoint}".split()
|
||||||
|
|
||||||
|
is_fsdp_ckpt = os.path.isdir(checkpoint) and (
|
||||||
|
# this checks the FSDP state dict when `SHARDED_STATE_DICT` is used
|
||||||
|
any(
|
||||||
|
FSDP_MODEL_NAME in folder_name
|
||||||
|
for folder_name in os.listdir(checkpoint)
|
||||||
|
if os.path.isdir(os.path.join(checkpoint, folder_name))
|
||||||
|
)
|
||||||
|
# this checks the FSDP state dict when `FULL_STATE_DICT` is used
|
||||||
|
or os.path.isfile(os.path.join(checkpoint, f"{FSDP_MODEL_NAME}.bin"))
|
||||||
|
)
|
||||||
|
self.assertTrue(is_fsdp_ckpt)
|
||||||
|
|
||||||
logs_resume = self.run_cmd_and_get_logs(
|
logs_resume = self.run_cmd_and_get_logs(
|
||||||
use_accelerate, sharding_strategy, launcher, script, resume_args, output_dir
|
use_accelerate, sharding_strategy, launcher, script, resume_args, output_dir
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user