From 976189a6df796a2ff442dd81b022626c840d8c27 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Tue, 9 Jan 2024 22:58:21 +0800 Subject: [PATCH] Fix initialization for missing parameters in `from_pretrained` under ZeRO-3 (#28245) * Fix initialization for missing parameters in `from_pretrained` under ZeRO-3 * Test initialization for missing parameters under ZeRO-3 * Add more tests * Only enable deepspeed context for per-module level parameters * Enable deepspeed context only once * Move class definition inside test case body --- src/transformers/modeling_utils.py | 34 +++++++++++--- tests/deepspeed/test_deepspeed.py | 72 ++++++++++++++++++++++++++++++ 2 files changed, 99 insertions(+), 7 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 05d74d6542..c093a86096 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -19,6 +19,7 @@ import functools import gc import importlib.metadata import inspect +import itertools import json import os import re @@ -544,10 +545,14 @@ def set_initialized_submodules(model, state_dict_keys): Sets the `_is_hf_initialized` flag in all submodules of a given model when all its weights are in the loaded state dict. """ + not_initialized_submodules = {} for module_name, module in model.named_modules(): - loaded_keys = [k.replace(f"{module_name}.", "") for k in state_dict_keys if k.startswith(f"{module_name}.")] - if len(set(module.state_dict().keys()) - set(loaded_keys)) == 0: + loaded_keys = {k.replace(f"{module_name}.", "") for k in state_dict_keys if k.startswith(f"{module_name}.")} + if loaded_keys.issuperset(module.state_dict()): module._is_hf_initialized = True + else: + not_initialized_submodules[module_name] = module + return not_initialized_submodules def _load_state_dict_into_model(model_to_load, state_dict, start_prefix): @@ -3917,7 +3922,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix elif add_prefix_to_model: expected_keys = [".".join([prefix, s]) for s in expected_keys] - missing_keys = list(set(expected_keys) - set(loaded_keys)) + missing_keys = sorted(set(expected_keys) - set(loaded_keys)) unexpected_keys = set(loaded_keys) - set(expected_keys) # Remove nonpersistent buffers from unexpected keys: they are not in the state dict but will be in the model # buffers @@ -3926,10 +3931,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix model_buffers = {key[len(_prefix) :] if key.startswith(_prefix) else key for key in model_buffers} elif add_prefix_to_model: model_buffers = {".".join([prefix, key]) for key in model_buffers} - unexpected_keys = list(unexpected_keys - model_buffers) + unexpected_keys = sorted(unexpected_keys - model_buffers) model.tie_weights() - if device_map is None and not is_fsdp_enabled(): + if device_map is None and not is_fsdp_enabled() and not is_deepspeed_zero3_enabled(): ptrs = collections.defaultdict(list) for name, tensor in model.state_dict().items(): id_tensor = id_tensor_storage(tensor) @@ -4000,9 +4005,24 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix _loaded_keys = [k[len(prefix) + 1 :] for k in loaded_keys] else: _loaded_keys = loaded_keys - set_initialized_submodules(model, _loaded_keys) + not_initialized_submodules = set_initialized_submodules(model, _loaded_keys) + else: + not_initialized_submodules = dict(model.named_modules()) # This will only initialize submodules that are not marked as initialized by the line above. - model.apply(model._initialize_weights) + if is_deepspeed_zero3_enabled(): + import deepspeed + + not_initialized_parameters = list( + set( + itertools.chain.from_iterable( + submodule.parameters(recurse=False) for submodule in not_initialized_submodules.values() + ) + ) + ) + with deepspeed.zero.GatheredParameters(not_initialized_parameters, modifier_rank=0): + model.apply(model._initialize_weights) + else: + model.apply(model._initialize_weights) # Set some modules to fp32 if any if keep_in_fp32_modules is not None: diff --git a/tests/deepspeed/test_deepspeed.py b/tests/deepspeed/test_deepspeed.py index 14c8f67031..982578d455 100644 --- a/tests/deepspeed/test_deepspeed.py +++ b/tests/deepspeed/test_deepspeed.py @@ -225,6 +225,78 @@ class CoreIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon): AutoModel.from_pretrained(T5_TINY) self.assertNotIn("Detected DeepSpeed ZeRO-3", cl.out) + def test_init_zero3_missing_params(self): + # test that zero.Init() for missing parameters works correctly under zero3 + import deepspeed + import torch + + from transformers.models.gpt2.modeling_gpt2 import GPT2PreTrainedModel + + class TinyGPT2WithUninitializedWeights(GPT2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.transformer = AutoModel.from_pretrained(GPT2_TINY, config=config) + self.new_head = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=True) + + def forward(self, *args, **kwargs): + transformer_outputs = self.transformer(*args, **kwargs) + hidden_states = transformer_outputs[0] + return self.new_head(hidden_states).float() + + def _init_weights(self, module): + super()._init_weights(module) + if module is self.new_head: + self.new_head.weight.data.fill_(-100.0) + self.new_head.bias.data.fill_(+100.0) + + ds_config = { + "train_batch_size": 1, + "zero_optimization": { + "stage": 3, + }, + } + + dschf = HfDeepSpeedConfig(ds_config) + + self.assertTrue(dschf.is_zero3()) + self.assertTrue(is_deepspeed_zero3_enabled()) + + with LoggingLevel(logging.INFO): + with mockenv_context(**self.dist_env_1_gpu): + logger = logging.get_logger("transformers.modeling_utils") + with CaptureLogger(logger) as cl: + model = TinyGPT2WithUninitializedWeights.from_pretrained(GPT2_TINY) + self.assertIn("Detected DeepSpeed ZeRO-3", cl.out) + self.assertRegex(cl.out, r"newly initialized.*new_head\.bias.*new_head\.weight") + with deepspeed.zero.GatheredParameters([model.new_head.weight, model.new_head.bias]): + self.assertTrue( + torch.allclose(model.new_head.weight, torch.tensor(-100.0, device=model.new_head.weight.device)), + ) + self.assertTrue( + torch.allclose(model.new_head.bias, torch.tensor(+100.0, device=model.new_head.bias.device)), + ) + + # now remove zero optimization + del ds_config["zero_optimization"] + dschf = HfDeepSpeedConfig(ds_config) + + self.assertFalse(dschf.is_zero3()) + self.assertFalse(is_deepspeed_zero3_enabled()) + + with LoggingLevel(logging.INFO): + with mockenv_context(**self.dist_env_1_gpu): + logger = logging.get_logger("transformers.modeling_utils") + with CaptureLogger(logger) as cl: + model = TinyGPT2WithUninitializedWeights.from_pretrained(GPT2_TINY) + self.assertNotIn("Detected DeepSpeed ZeRO-3", cl.out) + self.assertRegex(cl.out, r"newly initialized.*new_head\.bias.*new_head\.weight") + self.assertTrue( + torch.allclose(model.new_head.weight, torch.tensor(-100.0, device=model.new_head.weight.device)), + ) + self.assertTrue( + torch.allclose(model.new_head.bias, torch.tensor(+100.0, device=model.new_head.bias.device)), + ) + class TrainerIntegrationDeepSpeedWithCustomConfig(TestCasePlus): def setUp(self):