From 84aa13dd85ce5ec2023561ca304c5b41343dd347 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Sat, 5 Apr 2025 17:05:45 +0200 Subject: [PATCH] Fix deepspeed loading (#37281) * Update modeling_utils.py * Update modeling_utils.py * fix and remove all imports * Update modeling_utils.py * Update modeling_utils.py * style * Update modeling_utils.py --- src/transformers/modeling_utils.py | 48 ++++++++---------------------- 1 file changed, 12 insertions(+), 36 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 218c8dc6e9..cbeb857906 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -57,7 +57,7 @@ from .configuration_utils import PretrainedConfig from .dynamic_module_utils import custom_object_save from .generation import CompileConfig, GenerationConfig, GenerationMixin from .integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled -from .integrations.deepspeed import _load_state_dict_into_zero3_model +from .integrations.deepspeed import _load_state_dict_into_zero3_model, is_deepspeed_available from .integrations.flash_attention import flash_attention_forward from .integrations.flex_attention import flex_attention_forward from .integrations.sdpa_attention import sdpa_attention_forward @@ -153,6 +153,10 @@ if is_safetensors_available(): from safetensors.torch import load_file as safe_load_file from safetensors.torch import save_file as safe_save_file + +if is_deepspeed_available(): + import deepspeed + logger = logging.get_logger(__name__) @@ -2021,8 +2025,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ) if is_deepspeed_zero3_enabled() and not _is_quantized and not _is_ds_init_called: - import deepspeed - logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model") # this immediately partitions the model across all gpus, to avoid the overhead in time # and memory copying it on CPU or each GPU first @@ -2662,8 +2664,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix # Since we are basically reusing the same old embeddings with new weight values, gathering is required is_quantized = hasattr(self, "hf_quantizer") and self.hf_quantizer is not None if is_deepspeed_zero3_enabled() and not is_quantized: - import deepspeed - with deepspeed.zero.GatheredParameters(model_embeds.weight, modifier_rank=None): vocab_size = model_embeds.weight.shape[0] else: @@ -2694,8 +2694,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix # Update new_num_tokens with the actual size of new_embeddings if pad_to_multiple_of is not None: if is_deepspeed_zero3_enabled() and not is_quantized: - import deepspeed - with deepspeed.zero.GatheredParameters(new_embeddings.weight, modifier_rank=None): new_num_tokens = new_embeddings.weight.shape[0] else: @@ -2784,8 +2782,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix is_quantized = hasattr(self, "hf_quantizer") and self.hf_quantizer is not None if is_deepspeed_zero3_enabled() and not is_quantized: - import deepspeed - with deepspeed.zero.GatheredParameters(old_embeddings.weight, modifier_rank=None): old_num_tokens, old_embedding_dim = old_embeddings.weight.size() else: @@ -2830,8 +2826,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix added_num_tokens = new_num_tokens - old_num_tokens if is_deepspeed_zero3_enabled() and not is_quantized: - import deepspeed - with deepspeed.zero.GatheredParameters([old_embeddings.weight], modifier_rank=None): self._init_added_embeddings_weights_with_mean( old_embeddings, new_embeddings, old_embedding_dim, old_num_tokens, added_num_tokens @@ -2847,8 +2841,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix n = min(old_num_tokens, new_num_tokens) if is_deepspeed_zero3_enabled() and not is_quantized: - import deepspeed - params = [old_embeddings.weight, new_embeddings.weight] with deepspeed.zero.GatheredParameters(params, modifier_rank=0): new_embeddings.weight.data[:n, :] = old_embeddings.weight.data[:n, :] @@ -2859,8 +2851,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix # This ensures correct functionality when a Custom Embedding class is passed as input. # The input and output embedding types remain consistent. (c.f. https://github.com/huggingface/transformers/pull/31979) if is_deepspeed_zero3_enabled() and not is_quantized: - import deepspeed - params = [old_embeddings.weight, new_embeddings.weight] with deepspeed.zero.GatheredParameters(params, modifier_rank=0): old_embeddings.weight = new_embeddings.weight @@ -2918,8 +2908,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix is_quantized = hasattr(self, "hf_quantizer") and self.hf_quantizer is not None if is_deepspeed_zero3_enabled() and not is_quantized: - import deepspeed - with deepspeed.zero.GatheredParameters(old_lm_head.weight, modifier_rank=None): old_num_tokens, old_lm_head_dim = ( old_lm_head.weight.size() if not transposed else old_lm_head.weight.t().size() @@ -2970,8 +2958,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix added_num_tokens = new_num_tokens - old_num_tokens if is_deepspeed_zero3_enabled() and not is_quantized: - import deepspeed - params = [old_lm_head.weight] if has_new_lm_head_bias: params += [old_lm_head.bias] @@ -2992,8 +2978,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix num_tokens_to_copy = min(old_num_tokens, new_num_tokens) if is_deepspeed_zero3_enabled() and not is_quantized: - import deepspeed - params = [old_lm_head.weight, old_lm_head.bias, new_lm_head.weight, new_lm_head.bias] with deepspeed.zero.GatheredParameters(params, modifier_rank=0): self._copy_lm_head_original_to_resized( @@ -3738,14 +3722,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix return super().float(*args) @classmethod - def get_init_context( - cls: Type[SpecificPreTrainedModelType], - is_quantized=None, - _is_ds_init_called=None, - ): + def get_init_context(cls, is_quantized: bool, _is_ds_init_called: bool): if is_deepspeed_zero3_enabled() and not is_quantized and not _is_ds_init_called: - import deepspeed - logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model") init_contexts = [ deepspeed.zero.Init(config_dict_or_path=deepspeed_config()), @@ -4644,6 +4622,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ): # Useful flags is_quantized = hf_quantizer is not None + is_hqq_or_bnb = is_quantized and hf_quantizer.quantization_config.quant_method in [ + QuantizationMethod.HQQ, + QuantizationMethod.BITS_AND_BYTES, + ] # Get all the keys of the state dicts that we have to initialize the model if sharded_metadata is not None: @@ -4805,15 +4787,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix expected_keys = hf_quantizer.update_expected_keys(model_to_load, expected_keys, checkpoint_keys) # Warmup cuda to load the weights much faster on devices - if device_map is not None: # and hf_quantizer is None: + if device_map is not None: expanded_device_map = expand_device_map(device_map, expected_keys) caching_allocator_warmup(model_to_load, expanded_device_map, factor=2 if hf_quantizer is None else 4) error_msgs = [] - is_hqq_or_bnb = is_quantized and hf_quantizer.quantization_config.quant_method in [ - QuantizationMethod.HQQ, - QuantizationMethod.BITS_AND_BYTES, - ] # Iterate on all the shards to load the weights for shard_file in checkpoint_files: # Skip the load for shards that only contain disk-offloaded weights @@ -4821,7 +4799,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix continue map_location = "cpu" - if shard_file.endswith(".safetensors") and not is_hqq_or_bnb: + if shard_file.endswith(".safetensors") and not is_hqq_or_bnb and not is_deepspeed_zero3_enabled(): map_location = "meta" elif ( device_map is not None @@ -5267,8 +5245,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix not_initialized_submodules = dict(self.named_modules()) # This will only initialize submodules that are not marked as initialized by the line above. if is_deepspeed_zero3_enabled() and not is_quantized: - import deepspeed - not_initialized_parameters = list( set( itertools.chain.from_iterable(