Smangrul/fix failing ds ci tests (#27358)
* fix failing DeepSpeed CI tests due to `safetensors` being default
* debug
* remove debug statements
* resolve comments
* Update test_deepspeed.py
This commit is contained in:
committed by
GitHub
parent
ced9fd86f5
commit
7ecd229ba4
@@ -48,7 +48,7 @@ from transformers.testing_utils import (
|
|||||||
slow,
|
slow,
|
||||||
)
|
)
|
||||||
from transformers.trainer_utils import get_last_checkpoint, set_seed
|
from transformers.trainer_utils import get_last_checkpoint, set_seed
|
||||||
from transformers.utils import WEIGHTS_NAME, is_torch_bf16_gpu_available
|
from transformers.utils import SAFE_WEIGHTS_NAME, is_torch_bf16_gpu_available
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
@@ -565,8 +565,7 @@ class TrainerIntegrationDeepSpeed(TrainerIntegrationDeepSpeedWithCustomConfig, T
|
|||||||
|
|
||||||
def check_saved_checkpoints_deepspeed(self, output_dir, freq, total, stage, dtype):
|
def check_saved_checkpoints_deepspeed(self, output_dir, freq, total, stage, dtype):
|
||||||
# adapted from TrainerIntegrationCommon.check_saved_checkpoints
|
# adapted from TrainerIntegrationCommon.check_saved_checkpoints
|
||||||
|
file_list = [SAFE_WEIGHTS_NAME, "training_args.bin", "trainer_state.json", "config.json"]
|
||||||
file_list = [WEIGHTS_NAME, "training_args.bin", "trainer_state.json", "config.json"]
|
|
||||||
|
|
||||||
if stage == ZERO2:
|
if stage == ZERO2:
|
||||||
ds_file_list = ["mp_rank_00_model_states.pt"]
|
ds_file_list = ["mp_rank_00_model_states.pt"]
|
||||||
@@ -581,7 +580,6 @@ class TrainerIntegrationDeepSpeed(TrainerIntegrationDeepSpeedWithCustomConfig, T
|
|||||||
for step in range(freq, total, freq):
|
for step in range(freq, total, freq):
|
||||||
checkpoint = os.path.join(output_dir, f"checkpoint-{step}")
|
checkpoint = os.path.join(output_dir, f"checkpoint-{step}")
|
||||||
self.assertTrue(os.path.isdir(checkpoint), f"[{stage}] {checkpoint} dir is not found")
|
self.assertTrue(os.path.isdir(checkpoint), f"[{stage}] {checkpoint} dir is not found")
|
||||||
|
|
||||||
# common files
|
# common files
|
||||||
for filename in file_list:
|
for filename in file_list:
|
||||||
path = os.path.join(checkpoint, filename)
|
path = os.path.join(checkpoint, filename)
|
||||||
|
|||||||
Reference in New Issue
Block a user