[PEFT] Fix PEFT multi adapters support (#26407)

* fix PEFT multi adapters support

* refactor a bit

* save pretrained + BC + added tests

* Update src/transformers/integrations/peft.py

Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* add more tests

* add suggestion

* final changes

* adapt a bit

* fixup

* Update src/transformers/integrations/peft.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* adapt from suggestions

---------

Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Younes Belkada
2023-09-27 16:45:31 +02:00
committed by GitHub
parent 946bac798c
commit 3ca18d6d09
3 changed files with 76 additions and 11 deletions

View File

@@ -265,9 +265,11 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin):
_ = model.generate(input_ids=dummy_input)
model.set_adapter("default")
self.assertTrue(model.active_adapters() == ["default"])
self.assertTrue(model.active_adapter() == "default")
model.set_adapter("adapter-2")
self.assertTrue(model.active_adapters() == ["adapter-2"])
self.assertTrue(model.active_adapter() == "adapter-2")
# Logits comparison
@@ -276,6 +278,23 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin):
)
self.assertFalse(torch.allclose(logits_original_model, logits_adapter_2.logits, atol=1e-6, rtol=1e-6))
model.set_adapter(["adapter-2", "default"])
self.assertTrue(model.active_adapters() == ["adapter-2", "default"])
self.assertTrue(model.active_adapter() == "adapter-2")
logits_adapter_mixed = model(dummy_input)
self.assertFalse(
torch.allclose(logits_adapter_1.logits, logits_adapter_mixed.logits, atol=1e-6, rtol=1e-6)
)
self.assertFalse(
torch.allclose(logits_adapter_2.logits, logits_adapter_mixed.logits, atol=1e-6, rtol=1e-6)
)
# multi active adapter saving not supported
with self.assertRaises(ValueError), tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
@require_torch_gpu
def test_peft_from_pretrained_kwargs(self):
"""