Pass state dict (#35234)
* Pass state_dict argument to get_peft_model_state_dict * Style fix * Change arguments order
This commit is contained in:
@@ -446,7 +446,7 @@ class PeftAdapterMixin:
|
||||
|
||||
return self.active_adapters()[0]
|
||||
|
||||
def get_adapter_state_dict(self, adapter_name: Optional[str] = None) -> dict:
|
||||
def get_adapter_state_dict(self, adapter_name: Optional[str] = None, state_dict: Optional[dict] = None) -> dict:
|
||||
"""
|
||||
If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
|
||||
official documentation: https://huggingface.co/docs/peft
|
||||
@@ -457,6 +457,10 @@ class PeftAdapterMixin:
|
||||
Args:
|
||||
adapter_name (`str`, *optional*):
|
||||
The name of the adapter to get the state dict from. If no name is passed, the active adapter is used.
|
||||
state_dict (nested dictionary of `torch.Tensor`, *optional*)
|
||||
The state dictionary of the model. Will default to `self.state_dict()`, but can be used if special
|
||||
precautions need to be taken when recovering the state dictionary of a model (like when using model
|
||||
parallelism).
|
||||
"""
|
||||
check_peft_version(min_version=MIN_PEFT_VERSION)
|
||||
|
||||
@@ -468,7 +472,7 @@ class PeftAdapterMixin:
|
||||
if adapter_name is None:
|
||||
adapter_name = self.active_adapters()[0]
|
||||
|
||||
adapter_state_dict = get_peft_model_state_dict(self, adapter_name=adapter_name)
|
||||
adapter_state_dict = get_peft_model_state_dict(self, state_dict=state_dict, adapter_name=adapter_name)
|
||||
return adapter_state_dict
|
||||
|
||||
def _dispatch_accelerate_model(
|
||||
|
||||
@@ -3359,7 +3359,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
logger.info(
|
||||
"Detected adapters on the model, saving the model in the PEFT format, only adapter weights will be saved."
|
||||
)
|
||||
state_dict = model_to_save.get_adapter_state_dict()
|
||||
state_dict = model_to_save.get_adapter_state_dict(state_dict=state_dict)
|
||||
|
||||
if save_peft_format:
|
||||
logger.info(
|
||||
|
||||
Reference in New Issue
Block a user