Fix initialization for missing parameters in from_pretrained under ZeRO-3 (#28245)

* Fix initialization for missing parameters in `from_pretrained` under ZeRO-3

* Test initialization for missing parameters under ZeRO-3

* Add more tests

* Only enable deepspeed context for per-module level parameters

* Enable deepspeed context only once

* Move class definition inside test case body
This commit is contained in:
Xuehai Pan
2024-01-09 22:58:21 +08:00
committed by GitHub
parent 357971ec36
commit 976189a6df
2 changed files with 99 additions and 7 deletions

View File

@@ -225,6 +225,78 @@ class CoreIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
AutoModel.from_pretrained(T5_TINY)
self.assertNotIn("Detected DeepSpeed ZeRO-3", cl.out)
def test_init_zero3_missing_params(self):
# test that zero.Init() for missing parameters works correctly under zero3
import deepspeed
import torch
from transformers.models.gpt2.modeling_gpt2 import GPT2PreTrainedModel
class TinyGPT2WithUninitializedWeights(GPT2PreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.transformer = AutoModel.from_pretrained(GPT2_TINY, config=config)
self.new_head = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=True)
def forward(self, *args, **kwargs):
transformer_outputs = self.transformer(*args, **kwargs)
hidden_states = transformer_outputs[0]
return self.new_head(hidden_states).float()
def _init_weights(self, module):
super()._init_weights(module)
if module is self.new_head:
self.new_head.weight.data.fill_(-100.0)
self.new_head.bias.data.fill_(+100.0)
ds_config = {
"train_batch_size": 1,
"zero_optimization": {
"stage": 3,
},
}
dschf = HfDeepSpeedConfig(ds_config)
self.assertTrue(dschf.is_zero3())
self.assertTrue(is_deepspeed_zero3_enabled())
with LoggingLevel(logging.INFO):
with mockenv_context(**self.dist_env_1_gpu):
logger = logging.get_logger("transformers.modeling_utils")
with CaptureLogger(logger) as cl:
model = TinyGPT2WithUninitializedWeights.from_pretrained(GPT2_TINY)
self.assertIn("Detected DeepSpeed ZeRO-3", cl.out)
self.assertRegex(cl.out, r"newly initialized.*new_head\.bias.*new_head\.weight")
with deepspeed.zero.GatheredParameters([model.new_head.weight, model.new_head.bias]):
self.assertTrue(
torch.allclose(model.new_head.weight, torch.tensor(-100.0, device=model.new_head.weight.device)),
)
self.assertTrue(
torch.allclose(model.new_head.bias, torch.tensor(+100.0, device=model.new_head.bias.device)),
)
# now remove zero optimization
del ds_config["zero_optimization"]
dschf = HfDeepSpeedConfig(ds_config)
self.assertFalse(dschf.is_zero3())
self.assertFalse(is_deepspeed_zero3_enabled())
with LoggingLevel(logging.INFO):
with mockenv_context(**self.dist_env_1_gpu):
logger = logging.get_logger("transformers.modeling_utils")
with CaptureLogger(logger) as cl:
model = TinyGPT2WithUninitializedWeights.from_pretrained(GPT2_TINY)
self.assertNotIn("Detected DeepSpeed ZeRO-3", cl.out)
self.assertRegex(cl.out, r"newly initialized.*new_head\.bias.*new_head\.weight")
self.assertTrue(
torch.allclose(model.new_head.weight, torch.tensor(-100.0, device=model.new_head.weight.device)),
)
self.assertTrue(
torch.allclose(model.new_head.bias, torch.tensor(+100.0, device=model.new_head.bias.device)),
)
class TrainerIntegrationDeepSpeedWithCustomConfig(TestCasePlus):
def setUp(self):