Guard DeepSpeed imports (#37755)
* Guard DeepSpeed imports * Fix import * Import deepspeed consistently --------- Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
This commit is contained in:
@@ -57,7 +57,7 @@ from .dynamic_module_utils import custom_object_save
|
|||||||
from .generation import CompileConfig, GenerationConfig
|
from .generation import CompileConfig, GenerationConfig
|
||||||
from .integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled
|
from .integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled
|
||||||
from .integrations.accelerate import find_tied_parameters, init_empty_weights
|
from .integrations.accelerate import find_tied_parameters, init_empty_weights
|
||||||
from .integrations.deepspeed import _load_state_dict_into_zero3_model, is_deepspeed_available
|
from .integrations.deepspeed import _load_state_dict_into_zero3_model
|
||||||
from .integrations.flash_attention import flash_attention_forward
|
from .integrations.flash_attention import flash_attention_forward
|
||||||
from .integrations.flex_attention import flex_attention_forward
|
from .integrations.flex_attention import flex_attention_forward
|
||||||
from .integrations.sdpa_attention import sdpa_attention_forward
|
from .integrations.sdpa_attention import sdpa_attention_forward
|
||||||
@@ -154,9 +154,6 @@ if is_safetensors_available():
|
|||||||
from safetensors.torch import save_file as safe_save_file
|
from safetensors.torch import save_file as safe_save_file
|
||||||
|
|
||||||
|
|
||||||
if is_deepspeed_available():
|
|
||||||
import deepspeed
|
|
||||||
|
|
||||||
if is_kernels_available():
|
if is_kernels_available():
|
||||||
from kernels import get_kernel
|
from kernels import get_kernel
|
||||||
|
|
||||||
@@ -2007,6 +2004,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
|||||||
logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
|
logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
|
||||||
# this immediately partitions the model across all gpus, to avoid the overhead in time
|
# this immediately partitions the model across all gpus, to avoid the overhead in time
|
||||||
# and memory copying it on CPU or each GPU first
|
# and memory copying it on CPU or each GPU first
|
||||||
|
import deepspeed
|
||||||
|
|
||||||
init_contexts = [deepspeed.zero.Init(config_dict_or_path=deepspeed_config()), set_zero3_state()]
|
init_contexts = [deepspeed.zero.Init(config_dict_or_path=deepspeed_config()), set_zero3_state()]
|
||||||
with ContextManagers(init_contexts):
|
with ContextManagers(init_contexts):
|
||||||
model = cls(config, **kwargs)
|
model = cls(config, **kwargs)
|
||||||
@@ -2702,6 +2701,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
|||||||
# Since we are basically reusing the same old embeddings with new weight values, gathering is required
|
# Since we are basically reusing the same old embeddings with new weight values, gathering is required
|
||||||
is_quantized = hasattr(self, "hf_quantizer") and self.hf_quantizer is not None
|
is_quantized = hasattr(self, "hf_quantizer") and self.hf_quantizer is not None
|
||||||
if is_deepspeed_zero3_enabled() and not is_quantized:
|
if is_deepspeed_zero3_enabled() and not is_quantized:
|
||||||
|
import deepspeed
|
||||||
|
|
||||||
with deepspeed.zero.GatheredParameters(model_embeds.weight, modifier_rank=None):
|
with deepspeed.zero.GatheredParameters(model_embeds.weight, modifier_rank=None):
|
||||||
vocab_size = model_embeds.weight.shape[0]
|
vocab_size = model_embeds.weight.shape[0]
|
||||||
else:
|
else:
|
||||||
@@ -2732,6 +2733,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
|||||||
# Update new_num_tokens with the actual size of new_embeddings
|
# Update new_num_tokens with the actual size of new_embeddings
|
||||||
if pad_to_multiple_of is not None:
|
if pad_to_multiple_of is not None:
|
||||||
if is_deepspeed_zero3_enabled() and not is_quantized:
|
if is_deepspeed_zero3_enabled() and not is_quantized:
|
||||||
|
import deepspeed
|
||||||
|
|
||||||
with deepspeed.zero.GatheredParameters(new_embeddings.weight, modifier_rank=None):
|
with deepspeed.zero.GatheredParameters(new_embeddings.weight, modifier_rank=None):
|
||||||
new_num_tokens = new_embeddings.weight.shape[0]
|
new_num_tokens = new_embeddings.weight.shape[0]
|
||||||
else:
|
else:
|
||||||
@@ -2820,6 +2823,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
|||||||
|
|
||||||
is_quantized = hasattr(self, "hf_quantizer") and self.hf_quantizer is not None
|
is_quantized = hasattr(self, "hf_quantizer") and self.hf_quantizer is not None
|
||||||
if is_deepspeed_zero3_enabled() and not is_quantized:
|
if is_deepspeed_zero3_enabled() and not is_quantized:
|
||||||
|
import deepspeed
|
||||||
|
|
||||||
with deepspeed.zero.GatheredParameters(old_embeddings.weight, modifier_rank=None):
|
with deepspeed.zero.GatheredParameters(old_embeddings.weight, modifier_rank=None):
|
||||||
old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
|
old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
|
||||||
else:
|
else:
|
||||||
@@ -2864,6 +2869,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
|||||||
|
|
||||||
added_num_tokens = new_num_tokens - old_num_tokens
|
added_num_tokens = new_num_tokens - old_num_tokens
|
||||||
if is_deepspeed_zero3_enabled() and not is_quantized:
|
if is_deepspeed_zero3_enabled() and not is_quantized:
|
||||||
|
import deepspeed
|
||||||
|
|
||||||
with deepspeed.zero.GatheredParameters([old_embeddings.weight], modifier_rank=None):
|
with deepspeed.zero.GatheredParameters([old_embeddings.weight], modifier_rank=None):
|
||||||
self._init_added_embeddings_weights_with_mean(
|
self._init_added_embeddings_weights_with_mean(
|
||||||
old_embeddings, new_embeddings, old_embedding_dim, old_num_tokens, added_num_tokens
|
old_embeddings, new_embeddings, old_embedding_dim, old_num_tokens, added_num_tokens
|
||||||
@@ -2879,6 +2886,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
|||||||
n = min(old_num_tokens, new_num_tokens)
|
n = min(old_num_tokens, new_num_tokens)
|
||||||
|
|
||||||
if is_deepspeed_zero3_enabled() and not is_quantized:
|
if is_deepspeed_zero3_enabled() and not is_quantized:
|
||||||
|
import deepspeed
|
||||||
|
|
||||||
params = [old_embeddings.weight, new_embeddings.weight]
|
params = [old_embeddings.weight, new_embeddings.weight]
|
||||||
with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
|
with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
|
||||||
new_embeddings.weight.data[:n, :] = old_embeddings.weight.data[:n, :]
|
new_embeddings.weight.data[:n, :] = old_embeddings.weight.data[:n, :]
|
||||||
@@ -2889,6 +2898,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
|||||||
# This ensures correct functionality when a Custom Embedding class is passed as input.
|
# This ensures correct functionality when a Custom Embedding class is passed as input.
|
||||||
# The input and output embedding types remain consistent. (c.f. https://github.com/huggingface/transformers/pull/31979)
|
# The input and output embedding types remain consistent. (c.f. https://github.com/huggingface/transformers/pull/31979)
|
||||||
if is_deepspeed_zero3_enabled() and not is_quantized:
|
if is_deepspeed_zero3_enabled() and not is_quantized:
|
||||||
|
import deepspeed
|
||||||
|
|
||||||
params = [old_embeddings.weight, new_embeddings.weight]
|
params = [old_embeddings.weight, new_embeddings.weight]
|
||||||
with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
|
with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
|
||||||
old_embeddings.weight = new_embeddings.weight
|
old_embeddings.weight = new_embeddings.weight
|
||||||
@@ -2941,11 +2952,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
|||||||
`torch.nn.Linear`: Pointer to the resized Linear Module or the old Linear Module if `new_num_tokens` is
|
`torch.nn.Linear`: Pointer to the resized Linear Module or the old Linear Module if `new_num_tokens` is
|
||||||
`None`
|
`None`
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if new_num_tokens is None:
|
if new_num_tokens is None:
|
||||||
return old_lm_head
|
return old_lm_head
|
||||||
|
|
||||||
is_quantized = hasattr(self, "hf_quantizer") and self.hf_quantizer is not None
|
is_quantized = hasattr(self, "hf_quantizer") and self.hf_quantizer is not None
|
||||||
if is_deepspeed_zero3_enabled() and not is_quantized:
|
if is_deepspeed_zero3_enabled() and not is_quantized:
|
||||||
|
import deepspeed
|
||||||
|
|
||||||
with deepspeed.zero.GatheredParameters(old_lm_head.weight, modifier_rank=None):
|
with deepspeed.zero.GatheredParameters(old_lm_head.weight, modifier_rank=None):
|
||||||
old_num_tokens, old_lm_head_dim = (
|
old_num_tokens, old_lm_head_dim = (
|
||||||
old_lm_head.weight.size() if not transposed else old_lm_head.weight.t().size()
|
old_lm_head.weight.size() if not transposed else old_lm_head.weight.t().size()
|
||||||
@@ -2996,6 +3010,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
|||||||
|
|
||||||
added_num_tokens = new_num_tokens - old_num_tokens
|
added_num_tokens = new_num_tokens - old_num_tokens
|
||||||
if is_deepspeed_zero3_enabled() and not is_quantized:
|
if is_deepspeed_zero3_enabled() and not is_quantized:
|
||||||
|
import deepspeed
|
||||||
|
|
||||||
params = [old_lm_head.weight]
|
params = [old_lm_head.weight]
|
||||||
if has_new_lm_head_bias:
|
if has_new_lm_head_bias:
|
||||||
params += [old_lm_head.bias]
|
params += [old_lm_head.bias]
|
||||||
@@ -3016,6 +3032,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
|||||||
num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
|
num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
|
||||||
|
|
||||||
if is_deepspeed_zero3_enabled() and not is_quantized:
|
if is_deepspeed_zero3_enabled() and not is_quantized:
|
||||||
|
import deepspeed
|
||||||
|
|
||||||
params = [old_lm_head.weight, old_lm_head.bias, new_lm_head.weight, new_lm_head.bias]
|
params = [old_lm_head.weight, old_lm_head.bias, new_lm_head.weight, new_lm_head.bias]
|
||||||
with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
|
with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
|
||||||
self._copy_lm_head_original_to_resized(
|
self._copy_lm_head_original_to_resized(
|
||||||
@@ -3762,6 +3780,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
|||||||
@classmethod
|
@classmethod
|
||||||
def get_init_context(cls, is_quantized: bool, _is_ds_init_called: bool):
|
def get_init_context(cls, is_quantized: bool, _is_ds_init_called: bool):
|
||||||
if is_deepspeed_zero3_enabled():
|
if is_deepspeed_zero3_enabled():
|
||||||
|
import deepspeed
|
||||||
|
|
||||||
init_contexts = [no_init_weights()]
|
init_contexts = [no_init_weights()]
|
||||||
# We cannot initialize the model on meta device with deepspeed when not quantized
|
# We cannot initialize the model on meta device with deepspeed when not quantized
|
||||||
if not is_quantized and not _is_ds_init_called:
|
if not is_quantized and not _is_ds_init_called:
|
||||||
@@ -5349,6 +5369,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
|||||||
not_initialized_submodules = dict(self.named_modules())
|
not_initialized_submodules = dict(self.named_modules())
|
||||||
# This will only initialize submodules that are not marked as initialized by the line above.
|
# This will only initialize submodules that are not marked as initialized by the line above.
|
||||||
if is_deepspeed_zero3_enabled() and not is_quantized:
|
if is_deepspeed_zero3_enabled() and not is_quantized:
|
||||||
|
import deepspeed
|
||||||
|
|
||||||
not_initialized_parameters = list(
|
not_initialized_parameters = list(
|
||||||
set(
|
set(
|
||||||
itertools.chain.from_iterable(
|
itertools.chain.from_iterable(
|
||||||
|
|||||||
Reference in New Issue
Block a user