Fix loading zero3 weights (#36455)
* Check if fixes
* Fix zero3 loading
* Quality
* Fix marc nit
* Add fast tests
* Migrate to integrations.deepspeed rather than modeling_utils
* Style
This commit is contained in:
@@ -170,7 +170,6 @@ params_with_optims_and_schedulers = list(itertools.product(stages, dtypes, optim
|
||||
|
||||
|
||||
@require_deepspeed
|
||||
@require_torch_accelerator
|
||||
class CoreIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
|
||||
"""
|
||||
Testing non-Trainer DeepSpeed integration
|
||||
@@ -194,6 +193,42 @@ class CoreIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
|
||||
# reset the ds config global so that tests state doesn't leak
|
||||
unset_hf_deepspeed_config()
|
||||
|
||||
def test_init_zero3(self):
    """Check that the ZeRO-3 log line on model load follows the active DeepSpeed config.

    The `T5_TINY` checkpoint is loaded twice through `AutoModel.from_pretrained`:
    first with a stage-3 `zero_optimization` section in the config, where the
    "Detected DeepSpeed ZeRO-3" message is expected in the captured log, and then
    after that section has been deleted, where the message must be absent.
    """
    zero3_message = "Detected DeepSpeed ZeRO-3"

    def load_and_capture_log():
        # Load the model under INFO logging and the `dist_env_1_gpu` environment,
        # returning whatever the "transformers.modeling_utils" logger emitted.
        with LoggingLevel(logging.INFO), mockenv_context(**self.dist_env_1_gpu):
            modeling_logger = logging.get_logger("transformers.modeling_utils")
            with CaptureLogger(modeling_logger) as captured:
                AutoModel.from_pretrained(T5_TINY)
        return captured.out

    # phase 1: stage-3 config registered
    ds_config = {
        "train_batch_size": 1,
        "zero_optimization": {
            "stage": 3,
        },
    }
    # keep the config object bound to a local name for the duration of the phase
    dschf = HfDeepSpeedConfig(ds_config)

    self.assertTrue(dschf.is_zero3())
    self.assertTrue(is_deepspeed_zero3_enabled())
    self.assertIn(zero3_message, load_and_capture_log())

    # phase 2: same config with the zero optimization section removed
    ds_config.pop("zero_optimization")
    dschf = HfDeepSpeedConfig(ds_config)

    self.assertFalse(dschf.is_zero3())
    self.assertFalse(is_deepspeed_zero3_enabled())
    self.assertNotIn(zero3_message, load_and_capture_log())
|
||||
|
||||
@require_torch_accelerator
|
||||
def test_init_zero3_fp16(self):
|
||||
# test that zero.Init() works correctly under zero3/fp16
|
||||
ds_config = {
|
||||
@@ -201,6 +236,9 @@ class CoreIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
|
||||
"zero_optimization": {
|
||||
"stage": 3,
|
||||
},
|
||||
"fp16": {
|
||||
"enabled": True,
|
||||
},
|
||||
}
|
||||
|
||||
dschf = HfDeepSpeedConfig(ds_config)
|
||||
|
||||
Reference in New Issue
Block a user