Pass state dict (#35234)

* Pass state_dict argument to get_peft_model_state_dict

* Style fix

* Change arguments order
This commit is contained in:
Artem Kudisov
2025-03-20 13:54:59 +03:00
committed by GitHub
parent 957b05b413
commit 63380b77d4
2 changed files with 7 additions and 3 deletions

View File

@@ -446,7 +446,7 @@ class PeftAdapterMixin:
return self.active_adapters()[0]
def get_adapter_state_dict(self, adapter_name: Optional[str] = None) -> dict:
def get_adapter_state_dict(self, adapter_name: Optional[str] = None, state_dict: Optional[dict] = None) -> dict:
"""
If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
official documentation: https://huggingface.co/docs/peft
@@ -457,6 +457,10 @@ class PeftAdapterMixin:
Args:
adapter_name (`str`, *optional*):
The name of the adapter to get the state dict from. If no name is passed, the active adapter is used.
state_dict (nested dictionary of `torch.Tensor`, *optional*):
The state dictionary of the model. Will default to `self.state_dict()`, but can be used if special
precautions need to be taken when recovering the state dictionary of a model (like when using model
parallelism).
"""
check_peft_version(min_version=MIN_PEFT_VERSION)
@@ -468,7 +472,7 @@ class PeftAdapterMixin:
if adapter_name is None:
adapter_name = self.active_adapters()[0]
adapter_state_dict = get_peft_model_state_dict(self, adapter_name=adapter_name)
adapter_state_dict = get_peft_model_state_dict(self, state_dict=state_dict, adapter_name=adapter_name)
return adapter_state_dict
def _dispatch_accelerate_model(

View File

@@ -3359,7 +3359,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
logger.info(
"Detected adapters on the model, saving the model in the PEFT format, only adapter weights will be saved."
)
state_dict = model_to_save.get_adapter_state_dict()
state_dict = model_to_save.get_adapter_state_dict(state_dict=state_dict)
if save_peft_format:
logger.info(