From 38e96324ef63c79cbe36fd9d167adb8aeffe5484 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Thu, 28 Sep 2023 11:13:03 +0200 Subject: [PATCH] =?UTF-8?q?[`PEFT`]=C2=A0introducing=20`adapter=5Fkwargs`?= =?UTF-8?q?=20for=20loading=20adapters=20from=20different=20Hub=20location?= =?UTF-8?q?=20(`subfolder`,=20`revision`)=20than=20the=20base=20model=20(#?= =?UTF-8?q?26270)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * make use of adapter_revision * v1 adapter kwargs * fix CI * fix CI * fix CI * fixup * add BC * Update src/transformers/integrations/peft.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * fixup * change it to error * Update src/transformers/modeling_utils.py * Update src/transformers/modeling_utils.py * fixup * change * Update src/transformers/integrations/peft.py --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- src/transformers/integrations/peft.py | 21 ++++++++++++--- src/transformers/modeling_flax_utils.py | 3 +++ src/transformers/modeling_tf_utils.py | 3 +++ src/transformers/modeling_utils.py | 11 +++++--- src/transformers/models/auto/auto_factory.py | 12 +++++++-- .../peft_integration/test_peft_integration.py | 27 +++++++++++++++++++ 6 files changed, 68 insertions(+), 9 deletions(-) diff --git a/src/transformers/integrations/peft.py b/src/transformers/integrations/peft.py index fb0b122a23..aa4fc083df 100644 --- a/src/transformers/integrations/peft.py +++ b/src/transformers/integrations/peft.py @@ -77,6 +77,7 @@ class PeftAdapterMixin: offload_index: Optional[int] = None, peft_config: Dict[str, Any] = None, adapter_state_dict: Optional[Dict[str, "torch.Tensor"]] = None, + adapter_kwargs: Optional[Dict[str, Any]] = None, ) -> None: """ Load adapter weights from file or remote Hub folder. If you are not familiar with adapters and PEFT methods, we @@ -128,10 +129,15 @@ class PeftAdapterMixin: adapter_state_dict (`Dict[str, torch.Tensor]`, *optional*): The state dict of the adapter to load. This argument is used in case users directly pass PEFT state dicts + adapter_kwargs (`Dict[str, Any]`, *optional*): + Additional keyword arguments passed along to the `from_pretrained` method of the adapter config and + `find_adapter_config_file` method. """ check_peft_version(min_version=MIN_PEFT_VERSION) adapter_name = adapter_name if adapter_name is not None else "default" + if adapter_kwargs is None: + adapter_kwargs = {} from peft import PeftConfig, inject_adapter_in_model, load_peft_weights from peft.utils import set_peft_model_state_dict @@ -144,11 +150,20 @@ class PeftAdapterMixin: "You should either pass a `peft_model_id` or a `peft_config` and `adapter_state_dict` to load an adapter." ) + # We keep `revision` in the signature for backward compatibility + if revision is not None and "revision" not in adapter_kwargs: + adapter_kwargs["revision"] = revision + elif revision is not None and "revision" in adapter_kwargs and revision != adapter_kwargs["revision"]: + logger.error( + "You passed a `revision` argument both in `adapter_kwargs` and as a standalone argument. " + "The one in `adapter_kwargs` will be used." + ) + if peft_config is None: adapter_config_file = find_adapter_config_file( peft_model_id, - revision=revision, token=token, + **adapter_kwargs, ) if adapter_config_file is None: @@ -159,8 +174,8 @@ class PeftAdapterMixin: peft_config = PeftConfig.from_pretrained( peft_model_id, - revision=revision, use_auth_token=token, + **adapter_kwargs, ) # Create and add fresh new adapters into the model. @@ -170,7 +185,7 @@ class PeftAdapterMixin: self._hf_peft_config_loaded = True if peft_model_id is not None: - adapter_state_dict = load_peft_weights(peft_model_id, revision=revision, use_auth_token=token) + adapter_state_dict = load_peft_weights(peft_model_id, use_auth_token=token, **adapter_kwargs) # We need to pre-process the state dict to remove unneeded prefixes - for backward compatibility processed_adapter_state_dict = {} diff --git a/src/transformers/modeling_flax_utils.py b/src/transformers/modeling_flax_utils.py index c75086f09f..64a42609fc 100644 --- a/src/transformers/modeling_flax_utils.py +++ b/src/transformers/modeling_flax_utils.py @@ -623,6 +623,9 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin): subfolder = kwargs.pop("subfolder", "") commit_hash = kwargs.pop("_commit_hash", None) + # Not relevant for Flax Models + _ = kwargs.pop("adapter_kwargs", None) + if use_auth_token is not None: warnings.warn( "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index 4fbd984b16..6505a2ec6d 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -2645,6 +2645,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu commit_hash = kwargs.pop("_commit_hash", None) tf_to_pt_weight_rename = kwargs.pop("tf_to_pt_weight_rename", None) + # Not relevant for TF models + _ = kwargs.pop("adapter_kwargs", None) + if use_auth_token is not None: warnings.warn( "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index eeda7f7d48..eab0e6f2a8 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -2463,7 +2463,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix subfolder = kwargs.pop("subfolder", "") commit_hash = kwargs.pop("_commit_hash", None) variant = kwargs.pop("variant", None) - _adapter_model_path = kwargs.pop("_adapter_model_path", None) + adapter_kwargs = kwargs.pop("adapter_kwargs", {}) adapter_name = kwargs.pop("adapter_name", "default") use_flash_attention_2 = kwargs.pop("use_flash_attention_2", False) @@ -2516,6 +2516,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix commit_hash = getattr(config, "_commit_hash", None) if is_peft_available(): + _adapter_model_path = adapter_kwargs.pop("_adapter_model_path", None) + if _adapter_model_path is None: _adapter_model_path = find_adapter_config_file( pretrained_model_name_or_path, @@ -2525,14 +2527,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix proxies=proxies, local_files_only=local_files_only, token=token, - revision=revision, - subfolder=subfolder, _commit_hash=commit_hash, + **adapter_kwargs, ) if _adapter_model_path is not None and os.path.isfile(_adapter_model_path): with open(_adapter_model_path, "r", encoding="utf-8") as f: _adapter_model_path = pretrained_model_name_or_path pretrained_model_name_or_path = json.load(f)["base_model_name_or_path"] + else: + _adapter_model_path = None # change device_map into a map if we passed an int, a str or a torch.device if isinstance(device_map, torch.device): @@ -3371,8 +3374,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix model.load_adapter( _adapter_model_path, adapter_name=adapter_name, - revision=revision, token=token, + adapter_kwargs=adapter_kwargs, ) if output_loading_info: diff --git a/src/transformers/models/auto/auto_factory.py b/src/transformers/models/auto/auto_factory.py index f9f1abdd5a..5ee9029eb6 100644 --- a/src/transformers/models/auto/auto_factory.py +++ b/src/transformers/models/auto/auto_factory.py @@ -469,6 +469,7 @@ class _BaseAutoModelClass: hub_kwargs = {name: kwargs.pop(name) for name in hub_kwargs_names if name in kwargs} code_revision = kwargs.pop("code_revision", None) commit_hash = kwargs.pop("_commit_hash", None) + adapter_kwargs = kwargs.pop("adapter_kwargs", None) revision = hub_kwargs.pop("revision", None) hub_kwargs["revision"] = sanitize_code_revision(pretrained_model_name_or_path, revision, trust_remote_code) @@ -503,15 +504,18 @@ class _BaseAutoModelClass: commit_hash = getattr(config, "_commit_hash", None) if is_peft_available(): + if adapter_kwargs is None: + adapter_kwargs = {} + maybe_adapter_path = find_adapter_config_file( - pretrained_model_name_or_path, _commit_hash=commit_hash, **hub_kwargs + pretrained_model_name_or_path, _commit_hash=commit_hash, **adapter_kwargs ) if maybe_adapter_path is not None: with open(maybe_adapter_path, "r", encoding="utf-8") as f: adapter_config = json.load(f) - kwargs["_adapter_model_path"] = pretrained_model_name_or_path + adapter_kwargs["_adapter_model_path"] = pretrained_model_name_or_path pretrained_model_name_or_path = adapter_config["base_model_name_or_path"] if not isinstance(config, PretrainedConfig): @@ -545,6 +549,10 @@ class _BaseAutoModelClass: trust_remote_code = resolve_trust_remote_code( trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code ) + + # Set the adapter kwargs + kwargs["adapter_kwargs"] = adapter_kwargs + if has_remote_code and trust_remote_code: class_ref = config.auto_map[cls.__name__] model_class = get_class_from_dynamic_module( diff --git a/tests/peft_integration/test_peft_integration.py b/tests/peft_integration/test_peft_integration.py index ae8cbe5b4d..dbd0976dd4 100644 --- a/tests/peft_integration/test_peft_integration.py +++ b/tests/peft_integration/test_peft_integration.py @@ -351,3 +351,30 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin): # dummy generation _ = model.generate(input_ids=dummy_input) + + def test_peft_from_pretrained_hub_kwargs(self): + """ + Tests different combinations of PEFT model + from_pretrained + hub kwargs + """ + peft_model_id = "peft-internal-testing/tiny-opt-lora-revision" + + # This should not work + with self.assertRaises(OSError): + _ = AutoModelForCausalLM.from_pretrained(peft_model_id) + + adapter_kwargs = {"revision": "test"} + + # This should work + model = AutoModelForCausalLM.from_pretrained(peft_model_id, adapter_kwargs=adapter_kwargs) + self.assertTrue(self._check_lora_correctly_converted(model)) + + model = OPTForCausalLM.from_pretrained(peft_model_id, adapter_kwargs=adapter_kwargs) + self.assertTrue(self._check_lora_correctly_converted(model)) + + adapter_kwargs = {"revision": "main", "subfolder": "test_subfolder"} + + model = AutoModelForCausalLM.from_pretrained(peft_model_id, adapter_kwargs=adapter_kwargs) + self.assertTrue(self._check_lora_correctly_converted(model)) + + model = OPTForCausalLM.from_pretrained(peft_model_id, adapter_kwargs=adapter_kwargs) + self.assertTrue(self._check_lora_correctly_converted(model))