[PEFT] Fix PEFT multi adapters support (#26407)

* fix PEFT multi adapters support

* refactor a bit

* save pretrained + BC + added tests

* Update src/transformers/integrations/peft.py

Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* add more tests

* add suggestion

* final changes

* adapt a bit

* fixup

* Update src/transformers/integrations/peft.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* adapt from suggestions

---------

Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Younes Belkada
2023-09-27 16:45:31 +02:00
committed by GitHub
parent 946bac798c
commit 3ca18d6d09
3 changed files with 76 additions and 11 deletions

View File

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import inspect import inspect
from typing import TYPE_CHECKING, Any, Dict, Optional from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from ..utils import ( from ..utils import (
check_peft_version, check_peft_version,
@@ -245,7 +245,7 @@ class PeftAdapterMixin:
self.set_adapter(adapter_name) self.set_adapter(adapter_name)
def set_adapter(self, adapter_name: str) -> None: def set_adapter(self, adapter_name: Union[List[str], str]) -> None:
""" """
If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
official documentation: https://huggingface.co/docs/peft official documentation: https://huggingface.co/docs/peft
@@ -253,12 +253,19 @@ class PeftAdapterMixin:
Sets a specific adapter by forcing the model to use a that adapter and disable the other adapters. Sets a specific adapter by forcing the model to use a that adapter and disable the other adapters.
Args: Args:
adapter_name (`str`): adapter_name (`Union[List[str], str]`):
The name of the adapter to set. The name of the adapter to set. Can be also a list of strings to set multiple adapters.
""" """
check_peft_version(min_version=MIN_PEFT_VERSION) check_peft_version(min_version=MIN_PEFT_VERSION)
if not self._hf_peft_config_loaded: if not self._hf_peft_config_loaded:
raise ValueError("No adapter loaded. Please load an adapter first.") raise ValueError("No adapter loaded. Please load an adapter first.")
elif isinstance(adapter_name, list):
missing = set(adapter_name) - set(self.peft_config)
if len(missing) > 0:
raise ValueError(
f"Following adapter(s) could not be found: {', '.join(missing)}. Make sure you are passing the correct adapter name(s)."
f" current loaded adapters are: {list(self.peft_config.keys())}"
)
elif adapter_name not in self.peft_config: elif adapter_name not in self.peft_config:
raise ValueError( raise ValueError(
f"Adapter with name {adapter_name} not found. Please pass the correct adapter name among {list(self.peft_config.keys())}" f"Adapter with name {adapter_name} not found. Please pass the correct adapter name among {list(self.peft_config.keys())}"
@@ -270,6 +277,10 @@ class PeftAdapterMixin:
for _, module in self.named_modules(): for _, module in self.named_modules():
if isinstance(module, BaseTunerLayer): if isinstance(module, BaseTunerLayer):
# For backward compatbility with previous PEFT versions
if hasattr(module, "set_adapter"):
module.set_adapter(adapter_name)
else:
module.active_adapter = adapter_name module.active_adapter = adapter_name
_adapters_has_been_set = True _adapters_has_been_set = True
@@ -294,6 +305,10 @@ class PeftAdapterMixin:
for _, module in self.named_modules(): for _, module in self.named_modules():
if isinstance(module, BaseTunerLayer): if isinstance(module, BaseTunerLayer):
# The recent version of PEFT need to call `enable_adapters` instead
if hasattr(module, "enable_adapters"):
module.enable_adapters(enabled=False)
else:
module.disable_adapters = True module.disable_adapters = True
def enable_adapters(self) -> None: def enable_adapters(self) -> None:
@@ -312,14 +327,22 @@ class PeftAdapterMixin:
for _, module in self.named_modules(): for _, module in self.named_modules():
if isinstance(module, BaseTunerLayer): if isinstance(module, BaseTunerLayer):
# The recent version of PEFT need to call `enable_adapters` instead
if hasattr(module, "enable_adapters"):
module.enable_adapters(enabled=True)
else:
module.disable_adapters = False module.disable_adapters = False
def active_adapter(self) -> str: def active_adapters(self) -> List[str]:
""" """
If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
official documentation: https://huggingface.co/docs/peft official documentation: https://huggingface.co/docs/peft
Gets the current active adapter of the model. Gets the current active adapters of the model. In case of multi-adapter inference (combining multiple adapters
for inference) returns the list of all active adapters so that users can deal with them accordingly.
For previous PEFT versions (that does not support multi-adapter inference), `module.active_adapter` will return
a single string.
""" """
check_peft_version(min_version=MIN_PEFT_VERSION) check_peft_version(min_version=MIN_PEFT_VERSION)
@@ -333,7 +356,21 @@ class PeftAdapterMixin:
for _, module in self.named_modules(): for _, module in self.named_modules():
if isinstance(module, BaseTunerLayer): if isinstance(module, BaseTunerLayer):
return module.active_adapter active_adapters = module.active_adapter
break
# For previous PEFT versions
if isinstance(active_adapters, str):
active_adapters = [active_adapters]
return active_adapters
def active_adapter(self) -> str:
logger.warning(
"The `active_adapter` method is deprecated and will be removed in a future version. ", FutureWarning
)
return self.active_adapters()[0]
def get_adapter_state_dict(self, adapter_name: Optional[str] = None) -> dict: def get_adapter_state_dict(self, adapter_name: Optional[str] = None) -> dict:
""" """

View File

@@ -2006,7 +2006,16 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
peft_state_dict[f"base_model.model.{key}"] = value peft_state_dict[f"base_model.model.{key}"] = value
state_dict = peft_state_dict state_dict = peft_state_dict
current_peft_config = self.peft_config[self.active_adapter()] active_adapter = self.active_adapters()
if len(active_adapter) > 1:
raise ValueError(
"Multiple active adapters detected, saving multiple active adapters is not supported yet. You can save adapters separately one by one "
"by iteratively calling `model.set_adapter(adapter_name)` then `model.save_pretrained(...)`"
)
active_adapter = active_adapter[0]
current_peft_config = self.peft_config[active_adapter]
current_peft_config.save_pretrained(save_directory) current_peft_config.save_pretrained(save_directory)
# Save the model # Save the model

View File

@@ -265,9 +265,11 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin):
_ = model.generate(input_ids=dummy_input) _ = model.generate(input_ids=dummy_input)
model.set_adapter("default") model.set_adapter("default")
self.assertTrue(model.active_adapters() == ["default"])
self.assertTrue(model.active_adapter() == "default") self.assertTrue(model.active_adapter() == "default")
model.set_adapter("adapter-2") model.set_adapter("adapter-2")
self.assertTrue(model.active_adapters() == ["adapter-2"])
self.assertTrue(model.active_adapter() == "adapter-2") self.assertTrue(model.active_adapter() == "adapter-2")
# Logits comparison # Logits comparison
@@ -276,6 +278,23 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin):
) )
self.assertFalse(torch.allclose(logits_original_model, logits_adapter_2.logits, atol=1e-6, rtol=1e-6)) self.assertFalse(torch.allclose(logits_original_model, logits_adapter_2.logits, atol=1e-6, rtol=1e-6))
model.set_adapter(["adapter-2", "default"])
self.assertTrue(model.active_adapters() == ["adapter-2", "default"])
self.assertTrue(model.active_adapter() == "adapter-2")
logits_adapter_mixed = model(dummy_input)
self.assertFalse(
torch.allclose(logits_adapter_1.logits, logits_adapter_mixed.logits, atol=1e-6, rtol=1e-6)
)
self.assertFalse(
torch.allclose(logits_adapter_2.logits, logits_adapter_mixed.logits, atol=1e-6, rtol=1e-6)
)
# multi active adapter saving not supported
with self.assertRaises(ValueError), tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
@require_torch_gpu @require_torch_gpu
def test_peft_from_pretrained_kwargs(self): def test_peft_from_pretrained_kwargs(self):
""" """