[Peft] modules_to_save support for peft integration (#27466)
* `modules_to_save` support for peft integration * Update docs/source/en/peft.md Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * slightly elaborate test --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -98,7 +98,7 @@ You can use [`~peft.PeftModel.add_adapter`] to add a new adapter to a model with
|
|||||||
|
|
||||||
```py
|
```py
|
||||||
from transformers import AutoModelForCausalLM, OPTForCausalLM, AutoTokenizer
|
from transformers import AutoModelForCausalLM, OPTForCausalLM, AutoTokenizer
|
||||||
from peft import PeftConfig
|
from peft import LoraConfig
|
||||||
|
|
||||||
model_id = "facebook/opt-350m"
|
model_id = "facebook/opt-350m"
|
||||||
model = AutoModelForCausalLM.from_pretrained(model_id)
|
model = AutoModelForCausalLM.from_pretrained(model_id)
|
||||||
@@ -208,6 +208,26 @@ model.save_pretrained(save_dir)
|
|||||||
model = AutoModelForCausalLM.from_pretrained(save_dir)
|
model = AutoModelForCausalLM.from_pretrained(save_dir)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Add additional trainable layers to a PEFT adapter
|
||||||
|
|
||||||
|
You can also fine-tune additional trainable layers on top of a model that has adapters attached by passing `modules_to_save` in your PEFT config. For example, if you also want to fine-tune the `lm_head` on top of a model with a LoRA adapter:
|
||||||
|
|
||||||
|
```py
|
||||||
|
from transformers import AutoModelForCausalLM, OPTForCausalLM, AutoTokenizer
|
||||||
|
from peft import LoraConfig
|
||||||
|
|
||||||
|
model_id = "facebook/opt-350m"
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(model_id)
|
||||||
|
|
||||||
|
lora_config = LoraConfig(
|
||||||
|
target_modules=["q_proj", "k_proj"],
|
||||||
|
modules_to_save=["lm_head"],
|
||||||
|
)
|
||||||
|
|
||||||
|
model.add_adapter(lora_config)
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
<!--
|
<!--
|
||||||
TODO: (@younesbelkada @stevhliu)
|
TODO: (@younesbelkada @stevhliu)
|
||||||
- Link to PEFT docs for further details
|
- Link to PEFT docs for further details
|
||||||
|
|||||||
@@ -292,11 +292,12 @@ class PeftAdapterMixin:
|
|||||||
)
|
)
|
||||||
|
|
||||||
from peft.tuners.tuners_utils import BaseTunerLayer
|
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||||
|
from peft.utils import ModulesToSaveWrapper
|
||||||
|
|
||||||
_adapters_has_been_set = False
|
_adapters_has_been_set = False
|
||||||
|
|
||||||
for _, module in self.named_modules():
|
for _, module in self.named_modules():
|
||||||
if isinstance(module, BaseTunerLayer):
|
if isinstance(module, (BaseTunerLayer, ModulesToSaveWrapper)):
|
||||||
# For backward compatibility with previous PEFT versions
|
# For backward compatibility with previous PEFT versions
|
||||||
if hasattr(module, "set_adapter"):
|
if hasattr(module, "set_adapter"):
|
||||||
module.set_adapter(adapter_name)
|
module.set_adapter(adapter_name)
|
||||||
@@ -322,9 +323,10 @@ class PeftAdapterMixin:
|
|||||||
raise ValueError("No adapter loaded. Please load an adapter first.")
|
raise ValueError("No adapter loaded. Please load an adapter first.")
|
||||||
|
|
||||||
from peft.tuners.tuners_utils import BaseTunerLayer
|
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||||
|
from peft.utils import ModulesToSaveWrapper
|
||||||
|
|
||||||
for _, module in self.named_modules():
|
for _, module in self.named_modules():
|
||||||
if isinstance(module, BaseTunerLayer):
|
if isinstance(module, (BaseTunerLayer, ModulesToSaveWrapper)):
|
||||||
# Recent versions of PEFT need to call `enable_adapters` instead
|
# Recent versions of PEFT need to call `enable_adapters` instead
|
||||||
if hasattr(module, "enable_adapters"):
|
if hasattr(module, "enable_adapters"):
|
||||||
module.enable_adapters(enabled=False)
|
module.enable_adapters(enabled=False)
|
||||||
|
|||||||
@@ -182,6 +182,44 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin):
|
|||||||
model_from_pretrained = transformers_class.from_pretrained(tmpdirname).to(torch_device)
|
model_from_pretrained = transformers_class.from_pretrained(tmpdirname).to(torch_device)
|
||||||
self.assertTrue(self._check_lora_correctly_converted(model_from_pretrained))
|
self.assertTrue(self._check_lora_correctly_converted(model_from_pretrained))
|
||||||
|
|
||||||
|
def test_peft_add_adapter_modules_to_save(self):
|
||||||
|
"""
|
||||||
|
Simple test that tests if `add_adapter` works as expected when training with
|
||||||
|
modules to save.
|
||||||
|
"""
|
||||||
|
from peft import LoraConfig
|
||||||
|
from peft.utils import ModulesToSaveWrapper
|
||||||
|
|
||||||
|
for model_id in self.transformers_test_model_ids:
|
||||||
|
for transformers_class in self.transformers_test_model_classes:
|
||||||
|
dummy_input = torch.LongTensor([[0, 1, 2, 3, 4, 5, 6, 7]]).to(torch_device)
|
||||||
|
|
||||||
|
model = transformers_class.from_pretrained(model_id).to(torch_device)
|
||||||
|
peft_config = LoraConfig(init_lora_weights=False, modules_to_save=["lm_head"])
|
||||||
|
model.add_adapter(peft_config)
|
||||||
|
self._check_lora_correctly_converted(model)
|
||||||
|
|
||||||
|
_has_modules_to_save_wrapper = False
|
||||||
|
for name, module in model.named_modules():
|
||||||
|
if isinstance(module, ModulesToSaveWrapper):
|
||||||
|
_has_modules_to_save_wrapper = True
|
||||||
|
self.assertTrue(module.modules_to_save.default.weight.requires_grad)
|
||||||
|
self.assertTrue("lm_head" in name)
|
||||||
|
break
|
||||||
|
|
||||||
|
self.assertTrue(_has_modules_to_save_wrapper)
|
||||||
|
state_dict = model.get_adapter_state_dict()
|
||||||
|
|
||||||
|
self.assertTrue("lm_head.weight" in state_dict.keys())
|
||||||
|
|
||||||
|
logits = model(dummy_input).logits
|
||||||
|
loss = logits.mean()
|
||||||
|
loss.backward()
|
||||||
|
|
||||||
|
for _, param in model.named_parameters():
|
||||||
|
if param.requires_grad:
|
||||||
|
self.assertTrue(param.grad is not None)
|
||||||
|
|
||||||
def test_peft_add_adapter_training_gradient_checkpointing(self):
|
def test_peft_add_adapter_training_gradient_checkpointing(self):
|
||||||
"""
|
"""
|
||||||
Simple test that tests if `add_adapter` works as expected when training with
|
Simple test that tests if `add_adapter` works as expected when training with
|
||||||
|
|||||||
Reference in New Issue
Block a user