Pass state dict (#35234)

* Pass state_dict argument to get_peft_model_state_dict

* Style fix

* Change arguments order
This commit is contained in:
Artem Kudisov
2025-03-20 13:54:59 +03:00
committed by GitHub
parent 957b05b413
commit 63380b77d4
2 changed files with 7 additions and 3 deletions

View File

@@ -446,7 +446,7 @@ class PeftAdapterMixin:
return self.active_adapters()[0] return self.active_adapters()[0]
def get_adapter_state_dict(self, adapter_name: Optional[str] = None) -> dict: def get_adapter_state_dict(self, adapter_name: Optional[str] = None, state_dict: Optional[dict] = None) -> dict:
""" """
If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
official documentation: https://huggingface.co/docs/peft official documentation: https://huggingface.co/docs/peft
@@ -457,6 +457,10 @@ class PeftAdapterMixin:
Args: Args:
adapter_name (`str`, *optional*): adapter_name (`str`, *optional*):
The name of the adapter to get the state dict from. If no name is passed, the active adapter is used. The name of the adapter to get the state dict from. If no name is passed, the active adapter is used.
state_dict (nested dictionary of `torch.Tensor`, *optional*)
The state dictionary of the model. Will default to `self.state_dict()`, but can be used if special
precautions need to be taken when recovering the state dictionary of a model (like when using model
parallelism).
""" """
check_peft_version(min_version=MIN_PEFT_VERSION) check_peft_version(min_version=MIN_PEFT_VERSION)
@@ -468,7 +472,7 @@ class PeftAdapterMixin:
if adapter_name is None: if adapter_name is None:
adapter_name = self.active_adapters()[0] adapter_name = self.active_adapters()[0]
adapter_state_dict = get_peft_model_state_dict(self, adapter_name=adapter_name) adapter_state_dict = get_peft_model_state_dict(self, state_dict=state_dict, adapter_name=adapter_name)
return adapter_state_dict return adapter_state_dict
def _dispatch_accelerate_model( def _dispatch_accelerate_model(

View File

@@ -3359,7 +3359,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
logger.info( logger.info(
"Detected adapters on the model, saving the model in the PEFT format, only adapter weights will be saved." "Detected adapters on the model, saving the model in the PEFT format, only adapter weights will be saved."
) )
state_dict = model_to_save.get_adapter_state_dict() state_dict = model_to_save.get_adapter_state_dict(state_dict=state_dict)
if save_peft_format: if save_peft_format:
logger.info( logger.info(