From 0a55d9f7376f72ad3ff296d4249840021b03bcc4 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 15 Sep 2023 18:22:01 +0200 Subject: [PATCH] [PEFT] Allow PEFT model dict to be loaded (#25721) * Allow PEFT model dict to be loaded * make style * make style * Apply suggestions from code review * address comments * fixup * final change * added tests * fix test * better logic for handling if adapter has been loaded * Update tests/peft_integration/test_peft_integration.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --------- Co-authored-by: younesbelkada Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- src/transformers/integrations/peft.py | 66 ++++++++++++------- .../peft_integration/test_peft_integration.py | 32 +++++++++ 2 files changed, 76 insertions(+), 22 deletions(-) diff --git a/src/transformers/integrations/peft.py b/src/transformers/integrations/peft.py index 0c743b9f9b..5a5be96cf3 100644 --- a/src/transformers/integrations/peft.py +++ b/src/transformers/integrations/peft.py @@ -12,13 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. import inspect -from typing import Optional +from typing import TYPE_CHECKING, Any, Dict, Optional from ..utils import ( check_peft_version, find_adapter_config_file, is_accelerate_available, is_peft_available, + is_torch_available, logging, ) @@ -30,6 +31,11 @@ if is_accelerate_available(): # Minimum PEFT version supported for the integration MIN_PEFT_VERSION = "0.5.0" +if TYPE_CHECKING: + if is_torch_available(): + import torch + + logger = logging.get_logger(__name__) @@ -61,7 +67,7 @@ class PeftAdapterMixin: def load_adapter( self, - peft_model_id: str, + peft_model_id: Optional[str] = None, adapter_name: Optional[str] = None, revision: Optional[str] = None, token: Optional[str] = None, @@ -69,6 +75,8 @@ class PeftAdapterMixin: max_memory: Optional[str] = None, offload_folder: Optional[str] = None, offload_index: Optional[int] = None, + peft_config: Dict[str, Any] = None, + adapter_state_dict: Optional[Dict[str, "torch.Tensor"]] = None, ) -> None: """ Load adapter weights from file or remote Hub folder. If you are not familiar with adapters and PEFT methods, we @@ -77,7 +85,7 @@ class PeftAdapterMixin: Requires peft as a backend to load the adapter weights. Args: - peft_model_id (`str`): + peft_model_id (`str`, *optional*): The identifier of the model to look for on the Hub, or a local path to the saved adapter config file and adapter weights. adapter_name (`str`, *optional*): @@ -114,6 +122,12 @@ class PeftAdapterMixin: If the `device_map` contains any value `"disk"`, the folder where we will offload weights. offload_index (`int`, `optional`): `offload_index` argument to be passed to `accelerate.dispatch_model` method. + peft_config (`Dict[str, Any]`, *optional*): + The configuration of the adapter to add, supported adapters are non-prefix tuning and adaption prompts + methods. This argument is used in case users directly pass PEFT state dicts + adapter_state_dict (`Dict[str, torch.Tensor]`, *optional*): + The state dict of the adapter to load. This argument is used in case users directly pass PEFT state + dicts """ check_peft_version(min_version=MIN_PEFT_VERSION) @@ -122,33 +136,41 @@ class PeftAdapterMixin: from peft import PeftConfig, inject_adapter_in_model, load_peft_weights from peft.utils import set_peft_model_state_dict - if not self._hf_peft_config_loaded: - self._hf_peft_config_loaded = True - elif adapter_name in self.peft_config: + if self._hf_peft_config_loaded and adapter_name in self.peft_config: raise ValueError(f"Adapter with name {adapter_name} already exists. Please use a different name.") - adapter_config_file = find_adapter_config_file( - peft_model_id, - revision=revision, - token=token, - ) - - if adapter_config_file is None: + if peft_model_id is None and (adapter_state_dict is None and peft_config is None): raise ValueError( - f"adapter model file not found in {peft_model_id}. Make sure you are passing the correct path to the " - "adapter model." + "You should either pass a `peft_model_id` or a `peft_config` and `adapter_state_dict` to load an adapter." ) - loaded_peft_config = PeftConfig.from_pretrained( - peft_model_id, - revision=revision, - use_auth_token=token, - ) + if peft_config is None: + adapter_config_file = find_adapter_config_file( + peft_model_id, + revision=revision, + token=token, + ) + + if adapter_config_file is None: + raise ValueError( + f"adapter model file not found in {peft_model_id}. Make sure you are passing the correct path to the " + "adapter model." + ) + + peft_config = PeftConfig.from_pretrained( + peft_model_id, + revision=revision, + use_auth_token=token, + ) # Create and add fresh new adapters into the model. - inject_adapter_in_model(loaded_peft_config, self, adapter_name) + inject_adapter_in_model(peft_config, self, adapter_name) - adapter_state_dict = load_peft_weights(peft_model_id, revision=revision, use_auth_token=token) + if not self._hf_peft_config_loaded: + self._hf_peft_config_loaded = True + + if peft_model_id is not None: + adapter_state_dict = load_peft_weights(peft_model_id, revision=revision, use_auth_token=token) # We need to pre-process the state dict to remove unneeded prefixes - for backward compatibility processed_adapter_state_dict = {} diff --git a/tests/peft_integration/test_peft_integration.py b/tests/peft_integration/test_peft_integration.py index b238ce25cb..efa8d68705 100644 --- a/tests/peft_integration/test_peft_integration.py +++ b/tests/peft_integration/test_peft_integration.py @@ -16,6 +16,8 @@ import os import tempfile import unittest +from huggingface_hub import hf_hub_download + from transformers import AutoModelForCausalLM, OPTForCausalLM from transformers.testing_utils import require_peft, require_torch, require_torch_gpu, slow, torch_device from transformers.utils import is_torch_available @@ -300,3 +302,33 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin): for model_id in self.peft_test_model_ids: pipe = pipeline("text-generation", model_id) _ = pipe("Hello") + + def test_peft_add_adapter_with_state_dict(self): + """ + Simple test that tests the basic usage of PEFT model through `from_pretrained`. This test tests if + add_adapter works as expected with a state_dict being passed. + """ + from peft import LoraConfig + + dummy_input = torch.LongTensor([[0, 1, 2, 3, 4, 5, 6, 7]]).to(torch_device) + + for model_id, peft_model_id in zip(self.transformers_test_model_ids, self.peft_test_model_ids): + for transformers_class in self.transformers_test_model_classes: + model = transformers_class.from_pretrained(model_id).to(torch_device) + + peft_config = LoraConfig(init_lora_weights=False) + + with self.assertRaises(ValueError): + model.load_adapter(peft_model_id=None) + + state_dict_path = hf_hub_download(peft_model_id, "adapter_model.bin") + + dummy_state_dict = torch.load(state_dict_path) + + model.load_adapter(adapter_state_dict=dummy_state_dict, peft_config=peft_config) + with self.assertRaises(ValueError): + model.load_adapter(model.load_adapter(adapter_state_dict=dummy_state_dict, peft_config=None)) + self.assertTrue(self._check_lora_correctly_converted(model)) + + # dummy generation + _ = model.generate(input_ids=dummy_input)