[tests] fix deepspeed zero3 config for test_stage3_nvme_offload (#31881)
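Fix config: set `stage3_gather_16bit_weights_on_model_save` to True in the ZeRO-3 config dict used by test_stage3_nvme_offload.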
This commit is contained in:
@@ -545,6 +545,7 @@ class TrainerIntegrationDeepSpeed(TrainerIntegrationDeepSpeedWithCustomConfig, T
|
|||||||
ds_config_zero3_dict = self.get_config_dict(ZERO3)
|
ds_config_zero3_dict = self.get_config_dict(ZERO3)
|
||||||
ds_config_zero3_dict["zero_optimization"]["offload_optimizer"] = nvme_config
|
ds_config_zero3_dict["zero_optimization"]["offload_optimizer"] = nvme_config
|
||||||
ds_config_zero3_dict["zero_optimization"]["offload_param"] = nvme_config
|
ds_config_zero3_dict["zero_optimization"]["offload_param"] = nvme_config
|
||||||
|
ds_config_zero3_dict["zero_optimization"]["stage3_gather_16bit_weights_on_model_save"] = True
|
||||||
trainer = get_regression_trainer(local_rank=0, fp16=True, deepspeed=ds_config_zero3_dict)
|
trainer = get_regression_trainer(local_rank=0, fp16=True, deepspeed=ds_config_zero3_dict)
|
||||||
with CaptureLogger(deepspeed_logger) as cl:
|
with CaptureLogger(deepspeed_logger) as cl:
|
||||||
trainer.train()
|
trainer.train()
|
||||||
|
|||||||
Reference in New Issue
Block a user