From 63380b77d46575957d126624db1ac956a4167b9b Mon Sep 17 00:00:00 2001 From: Artem Kudisov <70983678+phos-phophy@users.noreply.github.com> Date: Thu, 20 Mar 2025 13:54:59 +0300 Subject: [PATCH] Pass state dict (#35234) * Pass state_dict argument to get_peft_model_state_dict * Style fix * Change arguments order --- src/transformers/integrations/peft.py | 8 ++++++-- src/transformers/modeling_utils.py | 2 +- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/src/transformers/integrations/peft.py b/src/transformers/integrations/peft.py index 791528a629..6aa3b137b1 100644 --- a/src/transformers/integrations/peft.py +++ b/src/transformers/integrations/peft.py @@ -446,7 +446,7 @@ class PeftAdapterMixin: return self.active_adapters()[0] - def get_adapter_state_dict(self, adapter_name: Optional[str] = None) -> dict: + def get_adapter_state_dict(self, adapter_name: Optional[str] = None, state_dict: Optional[dict] = None) -> dict: """ If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT official documentation: https://huggingface.co/docs/peft @@ -457,6 +457,10 @@ class PeftAdapterMixin: Args: adapter_name (`str`, *optional*): The name of the adapter to get the state dict from. If no name is passed, the active adapter is used. + state_dict (nested dictionary of `torch.Tensor`, *optional*) + The state dictionary of the model. Will default to `self.state_dict()`, but can be used if special + precautions need to be taken when recovering the state dictionary of a model (like when using model + parallelism). """ check_peft_version(min_version=MIN_PEFT_VERSION) @@ -468,7 +472,7 @@ class PeftAdapterMixin: if adapter_name is None: adapter_name = self.active_adapters()[0] - adapter_state_dict = get_peft_model_state_dict(self, adapter_name=adapter_name) + adapter_state_dict = get_peft_model_state_dict(self, state_dict=state_dict, adapter_name=adapter_name) return adapter_state_dict def _dispatch_accelerate_model( diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 4158c82b40..120af2b842 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -3359,7 +3359,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix logger.info( "Detected adapters on the model, saving the model in the PEFT format, only adapter weights will be saved." ) - state_dict = model_to_save.get_adapter_state_dict() + state_dict = model_to_save.get_adapter_state_dict(state_dict=state_dict) if save_peft_format: logger.info(