Fix initialization for missing parameters in from_pretrained under ZeRO-3 (#28245)
* Fix initialization for missing parameters in `from_pretrained` under ZeRO-3 * Test initialization for missing parameters under ZeRO-3 * Add more tests * Only enable deepspeed context for per-module level parameters * Enable deepspeed context only once * Move class definition inside test case body
This commit is contained in:
@@ -225,6 +225,78 @@ class CoreIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
|
||||
AutoModel.from_pretrained(T5_TINY)
|
||||
self.assertNotIn("Detected DeepSpeed ZeRO-3", cl.out)
|
||||
|
||||
def test_init_zero3_missing_params(self):
|
||||
# test that zero.Init() for missing parameters works correctly under zero3
|
||||
import deepspeed
|
||||
import torch
|
||||
|
||||
from transformers.models.gpt2.modeling_gpt2 import GPT2PreTrainedModel
|
||||
|
||||
class TinyGPT2WithUninitializedWeights(GPT2PreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.transformer = AutoModel.from_pretrained(GPT2_TINY, config=config)
|
||||
self.new_head = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=True)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
transformer_outputs = self.transformer(*args, **kwargs)
|
||||
hidden_states = transformer_outputs[0]
|
||||
return self.new_head(hidden_states).float()
|
||||
|
||||
def _init_weights(self, module):
|
||||
super()._init_weights(module)
|
||||
if module is self.new_head:
|
||||
self.new_head.weight.data.fill_(-100.0)
|
||||
self.new_head.bias.data.fill_(+100.0)
|
||||
|
||||
ds_config = {
|
||||
"train_batch_size": 1,
|
||||
"zero_optimization": {
|
||||
"stage": 3,
|
||||
},
|
||||
}
|
||||
|
||||
dschf = HfDeepSpeedConfig(ds_config)
|
||||
|
||||
self.assertTrue(dschf.is_zero3())
|
||||
self.assertTrue(is_deepspeed_zero3_enabled())
|
||||
|
||||
with LoggingLevel(logging.INFO):
|
||||
with mockenv_context(**self.dist_env_1_gpu):
|
||||
logger = logging.get_logger("transformers.modeling_utils")
|
||||
with CaptureLogger(logger) as cl:
|
||||
model = TinyGPT2WithUninitializedWeights.from_pretrained(GPT2_TINY)
|
||||
self.assertIn("Detected DeepSpeed ZeRO-3", cl.out)
|
||||
self.assertRegex(cl.out, r"newly initialized.*new_head\.bias.*new_head\.weight")
|
||||
with deepspeed.zero.GatheredParameters([model.new_head.weight, model.new_head.bias]):
|
||||
self.assertTrue(
|
||||
torch.allclose(model.new_head.weight, torch.tensor(-100.0, device=model.new_head.weight.device)),
|
||||
)
|
||||
self.assertTrue(
|
||||
torch.allclose(model.new_head.bias, torch.tensor(+100.0, device=model.new_head.bias.device)),
|
||||
)
|
||||
|
||||
# now remove zero optimization
|
||||
del ds_config["zero_optimization"]
|
||||
dschf = HfDeepSpeedConfig(ds_config)
|
||||
|
||||
self.assertFalse(dschf.is_zero3())
|
||||
self.assertFalse(is_deepspeed_zero3_enabled())
|
||||
|
||||
with LoggingLevel(logging.INFO):
|
||||
with mockenv_context(**self.dist_env_1_gpu):
|
||||
logger = logging.get_logger("transformers.modeling_utils")
|
||||
with CaptureLogger(logger) as cl:
|
||||
model = TinyGPT2WithUninitializedWeights.from_pretrained(GPT2_TINY)
|
||||
self.assertNotIn("Detected DeepSpeed ZeRO-3", cl.out)
|
||||
self.assertRegex(cl.out, r"newly initialized.*new_head\.bias.*new_head\.weight")
|
||||
self.assertTrue(
|
||||
torch.allclose(model.new_head.weight, torch.tensor(-100.0, device=model.new_head.weight.device)),
|
||||
)
|
||||
self.assertTrue(
|
||||
torch.allclose(model.new_head.bias, torch.tensor(+100.0, device=model.new_head.bias.device)),
|
||||
)
|
||||
|
||||
|
||||
class TrainerIntegrationDeepSpeedWithCustomConfig(TestCasePlus):
|
||||
def setUp(self):
|
||||
|
||||
Reference in New Issue
Block a user