From 3ca18d6d09ee0d1610a400ead6f6041394f66421 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Wed, 27 Sep 2023 16:45:31 +0200 Subject: [PATCH] [`PEFT`] Fix PEFT multi adapters support (#26407) * fix PEFT multi adapters support * refactor a bit * save pretrained + BC + added tests * Update src/transformers/integrations/peft.py Co-authored-by: Benjamin Bossan * add more tests * add suggestion * final changes * adapt a bit * fixup * Update src/transformers/integrations/peft.py Co-authored-by: Patrick von Platen * adapt from suggestions --------- Co-authored-by: Benjamin Bossan Co-authored-by: Patrick von Platen --- src/transformers/integrations/peft.py | 57 +++++++++++++++---- src/transformers/modeling_utils.py | 11 +++- .../peft_integration/test_peft_integration.py | 19 +++++++ 3 files changed, 76 insertions(+), 11 deletions(-) diff --git a/src/transformers/integrations/peft.py b/src/transformers/integrations/peft.py index 5a5be96cf3..fb0b122a23 100644 --- a/src/transformers/integrations/peft.py +++ b/src/transformers/integrations/peft.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import inspect -from typing import TYPE_CHECKING, Any, Dict, Optional +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from ..utils import ( check_peft_version, @@ -245,7 +245,7 @@ class PeftAdapterMixin: self.set_adapter(adapter_name) - def set_adapter(self, adapter_name: str) -> None: + def set_adapter(self, adapter_name: Union[List[str], str]) -> None: """ If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT official documentation: https://huggingface.co/docs/peft @@ -253,12 +253,19 @@ class PeftAdapterMixin: Sets a specific adapter by forcing the model to use a that adapter and disable the other adapters. Args: - adapter_name (`str`): - The name of the adapter to set. + adapter_name (`Union[List[str], str]`): + The name of the adapter to set. Can be also a list of strings to set multiple adapters. """ check_peft_version(min_version=MIN_PEFT_VERSION) if not self._hf_peft_config_loaded: raise ValueError("No adapter loaded. Please load an adapter first.") + elif isinstance(adapter_name, list): + missing = set(adapter_name) - set(self.peft_config) + if len(missing) > 0: + raise ValueError( + f"Following adapter(s) could not be found: {', '.join(missing)}. Make sure you are passing the correct adapter name(s)." + f" current loaded adapters are: {list(self.peft_config.keys())}" + ) elif adapter_name not in self.peft_config: raise ValueError( f"Adapter with name {adapter_name} not found. Please pass the correct adapter name among {list(self.peft_config.keys())}" @@ -270,7 +277,11 @@ class PeftAdapterMixin: for _, module in self.named_modules(): if isinstance(module, BaseTunerLayer): - module.active_adapter = adapter_name + # For backward compatbility with previous PEFT versions + if hasattr(module, "set_adapter"): + module.set_adapter(adapter_name) + else: + module.active_adapter = adapter_name _adapters_has_been_set = True if not _adapters_has_been_set: @@ -294,7 +305,11 @@ class PeftAdapterMixin: for _, module in self.named_modules(): if isinstance(module, BaseTunerLayer): - module.disable_adapters = True + # The recent version of PEFT need to call `enable_adapters` instead + if hasattr(module, "enable_adapters"): + module.enable_adapters(enabled=False) + else: + module.disable_adapters = True def enable_adapters(self) -> None: """ @@ -312,14 +327,22 @@ class PeftAdapterMixin: for _, module in self.named_modules(): if isinstance(module, BaseTunerLayer): - module.disable_adapters = False + # The recent version of PEFT need to call `enable_adapters` instead + if hasattr(module, "enable_adapters"): + module.enable_adapters(enabled=True) + else: + module.disable_adapters = False - def active_adapter(self) -> str: + def active_adapters(self) -> List[str]: """ If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT official documentation: https://huggingface.co/docs/peft - Gets the current active adapter of the model. + Gets the current active adapters of the model. In case of multi-adapter inference (combining multiple adapters + for inference) returns the list of all active adapters so that users can deal with them accordingly. + + For previous PEFT versions (that does not support multi-adapter inference), `module.active_adapter` will return + a single string. """ check_peft_version(min_version=MIN_PEFT_VERSION) @@ -333,7 +356,21 @@ class PeftAdapterMixin: for _, module in self.named_modules(): if isinstance(module, BaseTunerLayer): - return module.active_adapter + active_adapters = module.active_adapter + break + + # For previous PEFT versions + if isinstance(active_adapters, str): + active_adapters = [active_adapters] + + return active_adapters + + def active_adapter(self) -> str: + logger.warning( + "The `active_adapter` method is deprecated and will be removed in a future version. ", FutureWarning + ) + + return self.active_adapters()[0] def get_adapter_state_dict(self, adapter_name: Optional[str] = None) -> dict: """ diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index ad1fd846cb..eeda7f7d48 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -2006,7 +2006,16 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix peft_state_dict[f"base_model.model.{key}"] = value state_dict = peft_state_dict - current_peft_config = self.peft_config[self.active_adapter()] + active_adapter = self.active_adapters() + + if len(active_adapter) > 1: + raise ValueError( + "Multiple active adapters detected, saving multiple active adapters is not supported yet. You can save adapters separately one by one " + "by iteratively calling `model.set_adapter(adapter_name)` then `model.save_pretrained(...)`" + ) + active_adapter = active_adapter[0] + + current_peft_config = self.peft_config[active_adapter] current_peft_config.save_pretrained(save_directory) # Save the model diff --git a/tests/peft_integration/test_peft_integration.py b/tests/peft_integration/test_peft_integration.py index efa8d68705..ae8cbe5b4d 100644 --- a/tests/peft_integration/test_peft_integration.py +++ b/tests/peft_integration/test_peft_integration.py @@ -265,9 +265,11 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin): _ = model.generate(input_ids=dummy_input) model.set_adapter("default") + self.assertTrue(model.active_adapters() == ["default"]) self.assertTrue(model.active_adapter() == "default") model.set_adapter("adapter-2") + self.assertTrue(model.active_adapters() == ["adapter-2"]) self.assertTrue(model.active_adapter() == "adapter-2") # Logits comparison @@ -276,6 +278,23 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin): ) self.assertFalse(torch.allclose(logits_original_model, logits_adapter_2.logits, atol=1e-6, rtol=1e-6)) + model.set_adapter(["adapter-2", "default"]) + self.assertTrue(model.active_adapters() == ["adapter-2", "default"]) + self.assertTrue(model.active_adapter() == "adapter-2") + + logits_adapter_mixed = model(dummy_input) + self.assertFalse( + torch.allclose(logits_adapter_1.logits, logits_adapter_mixed.logits, atol=1e-6, rtol=1e-6) + ) + + self.assertFalse( + torch.allclose(logits_adapter_2.logits, logits_adapter_mixed.logits, atol=1e-6, rtol=1e-6) + ) + + # multi active adapter saving not supported + with self.assertRaises(ValueError), tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + @require_torch_gpu def test_peft_from_pretrained_kwargs(self): """