[tests] fix deepspeed zero3 config for test_stage3_nvme_offload (#31881)

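fix config: set `stage3_gather_16bit_weights_on_model_save` to `True` in the ZeRO-3 config dict used by `test_stage3_nvme_offload`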
This commit is contained in:
Fanli Lin
2024-07-16 22:11:37 +08:00
committed by GitHub
parent e0dfd7bcaf
commit 25e5e3fa56

View File

@@ -545,6 +545,7 @@ class TrainerIntegrationDeepSpeed(TrainerIntegrationDeepSpeedWithCustomConfig, T
ds_config_zero3_dict = self.get_config_dict(ZERO3)
ds_config_zero3_dict["zero_optimization"]["offload_optimizer"] = nvme_config
ds_config_zero3_dict["zero_optimization"]["offload_param"] = nvme_config
ds_config_zero3_dict["zero_optimization"]["stage3_gather_16bit_weights_on_model_save"] = True
trainer = get_regression_trainer(local_rank=0, fp16=True, deepspeed=ds_config_zero3_dict)
with CaptureLogger(deepspeed_logger) as cl:
trainer.train()