[Peft] modules_to_save support for peft integration (#27466)
* `modules_to_save` support for peft integration * Update docs/source/en/peft.md Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * slightly elaborate test --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -182,6 +182,44 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin):
|
||||
model_from_pretrained = transformers_class.from_pretrained(tmpdirname).to(torch_device)
|
||||
self.assertTrue(self._check_lora_correctly_converted(model_from_pretrained))
|
||||
|
||||
def test_peft_add_adapter_modules_to_save(self):
|
||||
"""
|
||||
Simple test that tests if `add_adapter` works as expected when training with
|
||||
modules to save.
|
||||
"""
|
||||
from peft import LoraConfig
|
||||
from peft.utils import ModulesToSaveWrapper
|
||||
|
||||
for model_id in self.transformers_test_model_ids:
|
||||
for transformers_class in self.transformers_test_model_classes:
|
||||
dummy_input = torch.LongTensor([[0, 1, 2, 3, 4, 5, 6, 7]]).to(torch_device)
|
||||
|
||||
model = transformers_class.from_pretrained(model_id).to(torch_device)
|
||||
peft_config = LoraConfig(init_lora_weights=False, modules_to_save=["lm_head"])
|
||||
model.add_adapter(peft_config)
|
||||
self._check_lora_correctly_converted(model)
|
||||
|
||||
_has_modules_to_save_wrapper = False
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, ModulesToSaveWrapper):
|
||||
_has_modules_to_save_wrapper = True
|
||||
self.assertTrue(module.modules_to_save.default.weight.requires_grad)
|
||||
self.assertTrue("lm_head" in name)
|
||||
break
|
||||
|
||||
self.assertTrue(_has_modules_to_save_wrapper)
|
||||
state_dict = model.get_adapter_state_dict()
|
||||
|
||||
self.assertTrue("lm_head.weight" in state_dict.keys())
|
||||
|
||||
logits = model(dummy_input).logits
|
||||
loss = logits.mean()
|
||||
loss.backward()
|
||||
|
||||
for _, param in model.named_parameters():
|
||||
if param.requires_grad:
|
||||
self.assertTrue(param.grad is not None)
|
||||
|
||||
def test_peft_add_adapter_training_gradient_checkpointing(self):
|
||||
"""
|
||||
Simple test that tests if `add_adapter` works as expected when training with
|
||||
|
||||
Reference in New Issue
Block a user