Fix loading zero3 weights (#36455)

* Check if fixes

* Fix zero3 loading

* Quality

* Fix marc nit

* Add fast tests

* Migrate to integrations.deepspeed rather than modeling_utils

* Style
This commit is contained in:
Zach Mueller
2025-03-03 09:05:58 -05:00
committed by GitHub
parent dcbdf7e962
commit 4d8259d245
3 changed files with 105 additions and 3 deletions

View File

@@ -170,7 +170,6 @@ params_with_optims_and_schedulers = list(itertools.product(stages, dtypes, optim
@require_deepspeed
@require_torch_accelerator
class CoreIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
"""
Testing non-Trainer DeepSpeed integration
@@ -194,6 +193,42 @@ class CoreIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
# reset the ds config global so that tests state doesn't leak
unset_hf_deepspeed_config()
def test_init_zero3(self):
# test that zero.Init() works correctly
ds_config = {
"train_batch_size": 1,
"zero_optimization": {
"stage": 3,
},
}
dschf = HfDeepSpeedConfig(ds_config)
self.assertTrue(dschf.is_zero3())
self.assertTrue(is_deepspeed_zero3_enabled())
with LoggingLevel(logging.INFO):
with mockenv_context(**self.dist_env_1_gpu):
logger = logging.get_logger("transformers.modeling_utils")
with CaptureLogger(logger) as cl:
AutoModel.from_pretrained(T5_TINY)
self.assertIn("Detected DeepSpeed ZeRO-3", cl.out)
# now remove zero optimization
del ds_config["zero_optimization"]
dschf = HfDeepSpeedConfig(ds_config)
self.assertFalse(dschf.is_zero3())
self.assertFalse(is_deepspeed_zero3_enabled())
with LoggingLevel(logging.INFO):
with mockenv_context(**self.dist_env_1_gpu):
logger = logging.get_logger("transformers.modeling_utils")
with CaptureLogger(logger) as cl:
AutoModel.from_pretrained(T5_TINY)
self.assertNotIn("Detected DeepSpeed ZeRO-3", cl.out)
@require_torch_accelerator
def test_init_zero3_fp16(self):
# test that zero.Init() works correctly under zero3/fp16
ds_config = {
@@ -201,6 +236,9 @@ class CoreIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
"zero_optimization": {
"stage": 3,
},
"fp16": {
"enabled": True,
},
}
dschf = HfDeepSpeedConfig(ds_config)