Fix initialization for missing parameters in `from_pretrained` under ZeRO-3 (#28245)
* Fix initialization for missing parameters in `from_pretrained` under ZeRO-3
* Test initialization for missing parameters under ZeRO-3
* Add more tests
* Only enable deepspeed context for per-module level parameters
* Enable deepspeed context only once
* Move class definition inside test case body
This commit is contained in:
@@ -19,6 +19,7 @@ import functools
|
||||
import gc
|
||||
import importlib.metadata
|
||||
import inspect
|
||||
import itertools
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
@@ -544,10 +545,14 @@ def set_initialized_submodules(model, state_dict_keys):
|
||||
Sets the `_is_hf_initialized` flag in all submodules of a given model when all its weights are in the loaded state
|
||||
dict.
|
||||
"""
|
||||
not_initialized_submodules = {}
|
||||
for module_name, module in model.named_modules():
|
||||
loaded_keys = [k.replace(f"{module_name}.", "") for k in state_dict_keys if k.startswith(f"{module_name}.")]
|
||||
if len(set(module.state_dict().keys()) - set(loaded_keys)) == 0:
|
||||
loaded_keys = {k.replace(f"{module_name}.", "") for k in state_dict_keys if k.startswith(f"{module_name}.")}
|
||||
if loaded_keys.issuperset(module.state_dict()):
|
||||
module._is_hf_initialized = True
|
||||
else:
|
||||
not_initialized_submodules[module_name] = module
|
||||
return not_initialized_submodules
|
||||
|
||||
|
||||
def _load_state_dict_into_model(model_to_load, state_dict, start_prefix):
|
||||
@@ -3917,7 +3922,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
elif add_prefix_to_model:
|
||||
expected_keys = [".".join([prefix, s]) for s in expected_keys]
|
||||
|
||||
missing_keys = list(set(expected_keys) - set(loaded_keys))
|
||||
missing_keys = sorted(set(expected_keys) - set(loaded_keys))
|
||||
unexpected_keys = set(loaded_keys) - set(expected_keys)
|
||||
# Remove nonpersistent buffers from unexpected keys: they are not in the state dict but will be in the model
|
||||
# buffers
|
||||
@@ -3926,10 +3931,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
model_buffers = {key[len(_prefix) :] if key.startswith(_prefix) else key for key in model_buffers}
|
||||
elif add_prefix_to_model:
|
||||
model_buffers = {".".join([prefix, key]) for key in model_buffers}
|
||||
unexpected_keys = list(unexpected_keys - model_buffers)
|
||||
unexpected_keys = sorted(unexpected_keys - model_buffers)
|
||||
|
||||
model.tie_weights()
|
||||
if device_map is None and not is_fsdp_enabled():
|
||||
if device_map is None and not is_fsdp_enabled() and not is_deepspeed_zero3_enabled():
|
||||
ptrs = collections.defaultdict(list)
|
||||
for name, tensor in model.state_dict().items():
|
||||
id_tensor = id_tensor_storage(tensor)
|
||||
@@ -4000,8 +4005,23 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
_loaded_keys = [k[len(prefix) + 1 :] for k in loaded_keys]
|
||||
else:
|
||||
_loaded_keys = loaded_keys
|
||||
set_initialized_submodules(model, _loaded_keys)
|
||||
not_initialized_submodules = set_initialized_submodules(model, _loaded_keys)
|
||||
else:
|
||||
not_initialized_submodules = dict(model.named_modules())
|
||||
# This will only initialize submodules that are not marked as initialized by the line above.
|
||||
if is_deepspeed_zero3_enabled():
|
||||
import deepspeed
|
||||
|
||||
not_initialized_parameters = list(
|
||||
set(
|
||||
itertools.chain.from_iterable(
|
||||
submodule.parameters(recurse=False) for submodule in not_initialized_submodules.values()
|
||||
)
|
||||
)
|
||||
)
|
||||
with deepspeed.zero.GatheredParameters(not_initialized_parameters, modifier_rank=0):
|
||||
model.apply(model._initialize_weights)
|
||||
else:
|
||||
model.apply(model._initialize_weights)
|
||||
|
||||
# Set some modules to fp32 if any
|
||||
|
||||
@@ -225,6 +225,78 @@ class CoreIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
|
||||
AutoModel.from_pretrained(T5_TINY)
|
||||
self.assertNotIn("Detected DeepSpeed ZeRO-3", cl.out)
|
||||
|
||||
def test_init_zero3_missing_params(self):
|
||||
# test that zero.Init() for missing parameters works correctly under zero3
|
||||
import deepspeed
|
||||
import torch
|
||||
|
||||
from transformers.models.gpt2.modeling_gpt2 import GPT2PreTrainedModel
|
||||
|
||||
class TinyGPT2WithUninitializedWeights(GPT2PreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.transformer = AutoModel.from_pretrained(GPT2_TINY, config=config)
|
||||
self.new_head = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=True)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
transformer_outputs = self.transformer(*args, **kwargs)
|
||||
hidden_states = transformer_outputs[0]
|
||||
return self.new_head(hidden_states).float()
|
||||
|
||||
def _init_weights(self, module):
|
||||
super()._init_weights(module)
|
||||
if module is self.new_head:
|
||||
self.new_head.weight.data.fill_(-100.0)
|
||||
self.new_head.bias.data.fill_(+100.0)
|
||||
|
||||
ds_config = {
|
||||
"train_batch_size": 1,
|
||||
"zero_optimization": {
|
||||
"stage": 3,
|
||||
},
|
||||
}
|
||||
|
||||
dschf = HfDeepSpeedConfig(ds_config)
|
||||
|
||||
self.assertTrue(dschf.is_zero3())
|
||||
self.assertTrue(is_deepspeed_zero3_enabled())
|
||||
|
||||
with LoggingLevel(logging.INFO):
|
||||
with mockenv_context(**self.dist_env_1_gpu):
|
||||
logger = logging.get_logger("transformers.modeling_utils")
|
||||
with CaptureLogger(logger) as cl:
|
||||
model = TinyGPT2WithUninitializedWeights.from_pretrained(GPT2_TINY)
|
||||
self.assertIn("Detected DeepSpeed ZeRO-3", cl.out)
|
||||
self.assertRegex(cl.out, r"newly initialized.*new_head\.bias.*new_head\.weight")
|
||||
with deepspeed.zero.GatheredParameters([model.new_head.weight, model.new_head.bias]):
|
||||
self.assertTrue(
|
||||
torch.allclose(model.new_head.weight, torch.tensor(-100.0, device=model.new_head.weight.device)),
|
||||
)
|
||||
self.assertTrue(
|
||||
torch.allclose(model.new_head.bias, torch.tensor(+100.0, device=model.new_head.bias.device)),
|
||||
)
|
||||
|
||||
# now remove zero optimization
|
||||
del ds_config["zero_optimization"]
|
||||
dschf = HfDeepSpeedConfig(ds_config)
|
||||
|
||||
self.assertFalse(dschf.is_zero3())
|
||||
self.assertFalse(is_deepspeed_zero3_enabled())
|
||||
|
||||
with LoggingLevel(logging.INFO):
|
||||
with mockenv_context(**self.dist_env_1_gpu):
|
||||
logger = logging.get_logger("transformers.modeling_utils")
|
||||
with CaptureLogger(logger) as cl:
|
||||
model = TinyGPT2WithUninitializedWeights.from_pretrained(GPT2_TINY)
|
||||
self.assertNotIn("Detected DeepSpeed ZeRO-3", cl.out)
|
||||
self.assertRegex(cl.out, r"newly initialized.*new_head\.bias.*new_head\.weight")
|
||||
self.assertTrue(
|
||||
torch.allclose(model.new_head.weight, torch.tensor(-100.0, device=model.new_head.weight.device)),
|
||||
)
|
||||
self.assertTrue(
|
||||
torch.allclose(model.new_head.bias, torch.tensor(+100.0, device=model.new_head.bias.device)),
|
||||
)
|
||||
|
||||
|
||||
class TrainerIntegrationDeepSpeedWithCustomConfig(TestCasePlus):
|
||||
def setUp(self):
|
||||
|
||||
Reference in New Issue
Block a user