[PEFT] Fix PEFT multi adapters support (#26407)
* fix PEFT multi adapters support * refactor a bit * save pretrained + BC + added tests * Update src/transformers/integrations/peft.py Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com> * add more tests * add suggestion * final changes * adapt a bit * fixup * Update src/transformers/integrations/peft.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * adapt from suggestions --------- Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -12,7 +12,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import inspect
|
import inspect
|
||||||
from typing import TYPE_CHECKING, Any, Dict, Optional
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
from ..utils import (
|
from ..utils import (
|
||||||
check_peft_version,
|
check_peft_version,
|
||||||
@@ -245,7 +245,7 @@ class PeftAdapterMixin:
|
|||||||
|
|
||||||
self.set_adapter(adapter_name)
|
self.set_adapter(adapter_name)
|
||||||
|
|
||||||
def set_adapter(self, adapter_name: str) -> None:
|
def set_adapter(self, adapter_name: Union[List[str], str]) -> None:
|
||||||
"""
|
"""
|
||||||
If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
|
If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
|
||||||
official documentation: https://huggingface.co/docs/peft
|
official documentation: https://huggingface.co/docs/peft
|
||||||
@@ -253,12 +253,19 @@ class PeftAdapterMixin:
|
|||||||
Sets a specific adapter by forcing the model to use a that adapter and disable the other adapters.
|
Sets a specific adapter by forcing the model to use a that adapter and disable the other adapters.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
adapter_name (`str`):
|
adapter_name (`Union[List[str], str]`):
|
||||||
The name of the adapter to set.
|
The name of the adapter to set. Can be also a list of strings to set multiple adapters.
|
||||||
"""
|
"""
|
||||||
check_peft_version(min_version=MIN_PEFT_VERSION)
|
check_peft_version(min_version=MIN_PEFT_VERSION)
|
||||||
if not self._hf_peft_config_loaded:
|
if not self._hf_peft_config_loaded:
|
||||||
raise ValueError("No adapter loaded. Please load an adapter first.")
|
raise ValueError("No adapter loaded. Please load an adapter first.")
|
||||||
|
elif isinstance(adapter_name, list):
|
||||||
|
missing = set(adapter_name) - set(self.peft_config)
|
||||||
|
if len(missing) > 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"Following adapter(s) could not be found: {', '.join(missing)}. Make sure you are passing the correct adapter name(s)."
|
||||||
|
f" current loaded adapters are: {list(self.peft_config.keys())}"
|
||||||
|
)
|
||||||
elif adapter_name not in self.peft_config:
|
elif adapter_name not in self.peft_config:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Adapter with name {adapter_name} not found. Please pass the correct adapter name among {list(self.peft_config.keys())}"
|
f"Adapter with name {adapter_name} not found. Please pass the correct adapter name among {list(self.peft_config.keys())}"
|
||||||
@@ -270,7 +277,11 @@ class PeftAdapterMixin:
|
|||||||
|
|
||||||
for _, module in self.named_modules():
|
for _, module in self.named_modules():
|
||||||
if isinstance(module, BaseTunerLayer):
|
if isinstance(module, BaseTunerLayer):
|
||||||
module.active_adapter = adapter_name
|
# For backward compatbility with previous PEFT versions
|
||||||
|
if hasattr(module, "set_adapter"):
|
||||||
|
module.set_adapter(adapter_name)
|
||||||
|
else:
|
||||||
|
module.active_adapter = adapter_name
|
||||||
_adapters_has_been_set = True
|
_adapters_has_been_set = True
|
||||||
|
|
||||||
if not _adapters_has_been_set:
|
if not _adapters_has_been_set:
|
||||||
@@ -294,7 +305,11 @@ class PeftAdapterMixin:
|
|||||||
|
|
||||||
for _, module in self.named_modules():
|
for _, module in self.named_modules():
|
||||||
if isinstance(module, BaseTunerLayer):
|
if isinstance(module, BaseTunerLayer):
|
||||||
module.disable_adapters = True
|
# The recent version of PEFT need to call `enable_adapters` instead
|
||||||
|
if hasattr(module, "enable_adapters"):
|
||||||
|
module.enable_adapters(enabled=False)
|
||||||
|
else:
|
||||||
|
module.disable_adapters = True
|
||||||
|
|
||||||
def enable_adapters(self) -> None:
|
def enable_adapters(self) -> None:
|
||||||
"""
|
"""
|
||||||
@@ -312,14 +327,22 @@ class PeftAdapterMixin:
|
|||||||
|
|
||||||
for _, module in self.named_modules():
|
for _, module in self.named_modules():
|
||||||
if isinstance(module, BaseTunerLayer):
|
if isinstance(module, BaseTunerLayer):
|
||||||
module.disable_adapters = False
|
# The recent version of PEFT need to call `enable_adapters` instead
|
||||||
|
if hasattr(module, "enable_adapters"):
|
||||||
|
module.enable_adapters(enabled=True)
|
||||||
|
else:
|
||||||
|
module.disable_adapters = False
|
||||||
|
|
||||||
def active_adapter(self) -> str:
|
def active_adapters(self) -> List[str]:
|
||||||
"""
|
"""
|
||||||
If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
|
If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
|
||||||
official documentation: https://huggingface.co/docs/peft
|
official documentation: https://huggingface.co/docs/peft
|
||||||
|
|
||||||
Gets the current active adapter of the model.
|
Gets the current active adapters of the model. In case of multi-adapter inference (combining multiple adapters
|
||||||
|
for inference) returns the list of all active adapters so that users can deal with them accordingly.
|
||||||
|
|
||||||
|
For previous PEFT versions (that does not support multi-adapter inference), `module.active_adapter` will return
|
||||||
|
a single string.
|
||||||
"""
|
"""
|
||||||
check_peft_version(min_version=MIN_PEFT_VERSION)
|
check_peft_version(min_version=MIN_PEFT_VERSION)
|
||||||
|
|
||||||
@@ -333,7 +356,21 @@ class PeftAdapterMixin:
|
|||||||
|
|
||||||
for _, module in self.named_modules():
|
for _, module in self.named_modules():
|
||||||
if isinstance(module, BaseTunerLayer):
|
if isinstance(module, BaseTunerLayer):
|
||||||
return module.active_adapter
|
active_adapters = module.active_adapter
|
||||||
|
break
|
||||||
|
|
||||||
|
# For previous PEFT versions
|
||||||
|
if isinstance(active_adapters, str):
|
||||||
|
active_adapters = [active_adapters]
|
||||||
|
|
||||||
|
return active_adapters
|
||||||
|
|
||||||
|
def active_adapter(self) -> str:
|
||||||
|
logger.warning(
|
||||||
|
"The `active_adapter` method is deprecated and will be removed in a future version. ", FutureWarning
|
||||||
|
)
|
||||||
|
|
||||||
|
return self.active_adapters()[0]
|
||||||
|
|
||||||
def get_adapter_state_dict(self, adapter_name: Optional[str] = None) -> dict:
|
def get_adapter_state_dict(self, adapter_name: Optional[str] = None) -> dict:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -2006,7 +2006,16 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
peft_state_dict[f"base_model.model.{key}"] = value
|
peft_state_dict[f"base_model.model.{key}"] = value
|
||||||
state_dict = peft_state_dict
|
state_dict = peft_state_dict
|
||||||
|
|
||||||
current_peft_config = self.peft_config[self.active_adapter()]
|
active_adapter = self.active_adapters()
|
||||||
|
|
||||||
|
if len(active_adapter) > 1:
|
||||||
|
raise ValueError(
|
||||||
|
"Multiple active adapters detected, saving multiple active adapters is not supported yet. You can save adapters separately one by one "
|
||||||
|
"by iteratively calling `model.set_adapter(adapter_name)` then `model.save_pretrained(...)`"
|
||||||
|
)
|
||||||
|
active_adapter = active_adapter[0]
|
||||||
|
|
||||||
|
current_peft_config = self.peft_config[active_adapter]
|
||||||
current_peft_config.save_pretrained(save_directory)
|
current_peft_config.save_pretrained(save_directory)
|
||||||
|
|
||||||
# Save the model
|
# Save the model
|
||||||
|
|||||||
@@ -265,9 +265,11 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin):
|
|||||||
_ = model.generate(input_ids=dummy_input)
|
_ = model.generate(input_ids=dummy_input)
|
||||||
|
|
||||||
model.set_adapter("default")
|
model.set_adapter("default")
|
||||||
|
self.assertTrue(model.active_adapters() == ["default"])
|
||||||
self.assertTrue(model.active_adapter() == "default")
|
self.assertTrue(model.active_adapter() == "default")
|
||||||
|
|
||||||
model.set_adapter("adapter-2")
|
model.set_adapter("adapter-2")
|
||||||
|
self.assertTrue(model.active_adapters() == ["adapter-2"])
|
||||||
self.assertTrue(model.active_adapter() == "adapter-2")
|
self.assertTrue(model.active_adapter() == "adapter-2")
|
||||||
|
|
||||||
# Logits comparison
|
# Logits comparison
|
||||||
@@ -276,6 +278,23 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin):
|
|||||||
)
|
)
|
||||||
self.assertFalse(torch.allclose(logits_original_model, logits_adapter_2.logits, atol=1e-6, rtol=1e-6))
|
self.assertFalse(torch.allclose(logits_original_model, logits_adapter_2.logits, atol=1e-6, rtol=1e-6))
|
||||||
|
|
||||||
|
model.set_adapter(["adapter-2", "default"])
|
||||||
|
self.assertTrue(model.active_adapters() == ["adapter-2", "default"])
|
||||||
|
self.assertTrue(model.active_adapter() == "adapter-2")
|
||||||
|
|
||||||
|
logits_adapter_mixed = model(dummy_input)
|
||||||
|
self.assertFalse(
|
||||||
|
torch.allclose(logits_adapter_1.logits, logits_adapter_mixed.logits, atol=1e-6, rtol=1e-6)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertFalse(
|
||||||
|
torch.allclose(logits_adapter_2.logits, logits_adapter_mixed.logits, atol=1e-6, rtol=1e-6)
|
||||||
|
)
|
||||||
|
|
||||||
|
# multi active adapter saving not supported
|
||||||
|
with self.assertRaises(ValueError), tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
|
model.save_pretrained(tmpdirname)
|
||||||
|
|
||||||
@require_torch_gpu
|
@require_torch_gpu
|
||||||
def test_peft_from_pretrained_kwargs(self):
|
def test_peft_from_pretrained_kwargs(self):
|
||||||
"""
|
"""
|
||||||
|
|||||||
Reference in New Issue
Block a user