From 7ecd229ba475dbf78040f368ae86c86bba875442 Mon Sep 17 00:00:00 2001
From: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com>
Date: Thu, 9 Nov 2023 11:47:24 +0530
Subject: [PATCH] Smangrul/fix failing ds ci tests (#27358)

* fix failing DeepSpeed CI tests due to `safetensors` being default

* debug

* remove debug statements

* resolve comments

* Update test_deepspeed.py
---
 tests/deepspeed/test_deepspeed.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/tests/deepspeed/test_deepspeed.py b/tests/deepspeed/test_deepspeed.py
index 5e8c0d1629..9daad85b02 100644
--- a/tests/deepspeed/test_deepspeed.py
+++ b/tests/deepspeed/test_deepspeed.py
@@ -48,7 +48,7 @@ from transformers.testing_utils import (
     slow,
 )
 from transformers.trainer_utils import get_last_checkpoint, set_seed
-from transformers.utils import WEIGHTS_NAME, is_torch_bf16_gpu_available
+from transformers.utils import SAFE_WEIGHTS_NAME, is_torch_bf16_gpu_available
 
 
 if is_torch_available():
@@ -565,8 +565,7 @@ class TrainerIntegrationDeepSpeed(TrainerIntegrationDeepSpeedWithCustomConfig, T
 
     def check_saved_checkpoints_deepspeed(self, output_dir, freq, total, stage, dtype):
         # adapted from TrainerIntegrationCommon.check_saved_checkpoints
-
-        file_list = [WEIGHTS_NAME, "training_args.bin", "trainer_state.json", "config.json"]
+        file_list = [SAFE_WEIGHTS_NAME, "training_args.bin", "trainer_state.json", "config.json"]
 
         if stage == ZERO2:
             ds_file_list = ["mp_rank_00_model_states.pt"]
@@ -581,7 +580,6 @@ class TrainerIntegrationDeepSpeed(TrainerIntegrationDeepSpeedWithCustomConfig, T
         for step in range(freq, total, freq):
             checkpoint = os.path.join(output_dir, f"checkpoint-{step}")
             self.assertTrue(os.path.isdir(checkpoint), f"[{stage}] {checkpoint} dir is not found")
-
             # common files
             for filename in file_list:
                 path = os.path.join(checkpoint, filename)