Fix init empty weights without accelerate (#37337)
* add the integration * Update accelerate.py * Update accelerate.py * add find_tied_params as well * Update accelerate.py * add where copied from * simplify * add error
This commit is contained in:
@@ -57,6 +57,7 @@ from .configuration_utils import PretrainedConfig
|
||||
from .dynamic_module_utils import custom_object_save
|
||||
from .generation import CompileConfig, GenerationConfig, GenerationMixin
|
||||
from .integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled
|
||||
from .integrations.accelerate import find_tied_parameters, init_empty_weights
|
||||
from .integrations.deepspeed import _load_state_dict_into_zero3_model, is_deepspeed_available
|
||||
from .integrations.flash_attention import flash_attention_forward
|
||||
from .integrations.flex_attention import flex_attention_forward
|
||||
@@ -131,12 +132,11 @@ XLA_DOWNCAST_BF16 = os.environ.get("XLA_DOWNCAST_BF16", "0").upper()
|
||||
|
||||
|
||||
if is_accelerate_available():
|
||||
from accelerate import dispatch_model, infer_auto_device_map, init_empty_weights
|
||||
from accelerate import dispatch_model, infer_auto_device_map
|
||||
from accelerate.hooks import add_hook_to_module
|
||||
from accelerate.utils import (
|
||||
check_tied_parameters_on_same_device,
|
||||
extract_model_from_parallel,
|
||||
find_tied_parameters,
|
||||
get_balanced_memory,
|
||||
get_max_memory,
|
||||
load_offloaded_weights,
|
||||
@@ -4135,6 +4135,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
if device_map is not None:
|
||||
if is_deepspeed_zero3_enabled():
|
||||
raise ValueError("DeepSpeed Zero-3 is not compatible with passing a `device_map`.")
|
||||
if not is_accelerate_available():
|
||||
raise ValueError(
|
||||
"Using a `device_map` or `tp_plan` requires `accelerate`. You can install it with `pip install accelerate`"
|
||||
)
|
||||
|
||||
# handling bnb config from kwargs, remove after `load_in_{4/8}bit` deprecation.
|
||||
if load_in_4bit or load_in_8bit:
|
||||
|
||||
Reference in New Issue
Block a user