[PEFT] Fix PEFT multi adapters support (#26407)
* fix PEFT multi adapters support * refactor a bit * save pretrained + BC + added tests * Update src/transformers/integrations/peft.py Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com> * add more tests * add suggestion * final changes * adapt a bit * fixup * Update src/transformers/integrations/peft.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * adapt from suggestions --------- Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -265,9 +265,11 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin):
|
||||
_ = model.generate(input_ids=dummy_input)
|
||||
|
||||
model.set_adapter("default")
|
||||
self.assertTrue(model.active_adapters() == ["default"])
|
||||
self.assertTrue(model.active_adapter() == "default")
|
||||
|
||||
model.set_adapter("adapter-2")
|
||||
self.assertTrue(model.active_adapters() == ["adapter-2"])
|
||||
self.assertTrue(model.active_adapter() == "adapter-2")
|
||||
|
||||
# Logits comparison
|
||||
@@ -276,6 +278,23 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin):
|
||||
)
|
||||
self.assertFalse(torch.allclose(logits_original_model, logits_adapter_2.logits, atol=1e-6, rtol=1e-6))
|
||||
|
||||
model.set_adapter(["adapter-2", "default"])
|
||||
self.assertTrue(model.active_adapters() == ["adapter-2", "default"])
|
||||
self.assertTrue(model.active_adapter() == "adapter-2")
|
||||
|
||||
logits_adapter_mixed = model(dummy_input)
|
||||
self.assertFalse(
|
||||
torch.allclose(logits_adapter_1.logits, logits_adapter_mixed.logits, atol=1e-6, rtol=1e-6)
|
||||
)
|
||||
|
||||
self.assertFalse(
|
||||
torch.allclose(logits_adapter_2.logits, logits_adapter_mixed.logits, atol=1e-6, rtol=1e-6)
|
||||
)
|
||||
|
||||
# multi active adapter saving not supported
|
||||
with self.assertRaises(ValueError), tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
|
||||
@require_torch_gpu
|
||||
def test_peft_from_pretrained_kwargs(self):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user