added logic for deleting adapters once loaded (#34650)
* added logic for deleting adapters once loaded * updated to the latest version of transformers, merged utility function into the source * updated with missing check * added peft version check * Apply suggestions from code review Co-authored-by: Anton Vlasjuk <73884904+vasqu@users.noreply.github.com> * changes according to reviewer * added test for deleting adapter(s) * styling changes * styling changes in test * removed redundant code * formatted my contributions with ruff * optimized error handling * ruff formatted with correct config * resolved formatting issues --------- Co-authored-by: Anton Vlasjuk <73884904+vasqu@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
1650e0e514
commit
ca00950057
@@ -350,7 +350,6 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin):
|
||||
self.assertFalse(
|
||||
torch.allclose(logits_adapter_1.logits, logits_adapter_mixed.logits, atol=1e-6, rtol=1e-6)
|
||||
)
|
||||
|
||||
self.assertFalse(
|
||||
torch.allclose(logits_adapter_2.logits, logits_adapter_mixed.logits, atol=1e-6, rtol=1e-6)
|
||||
)
|
||||
@@ -359,6 +358,70 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin):
|
||||
with self.assertRaises(ValueError), tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
|
||||
def test_delete_adapter(self):
|
||||
"""
|
||||
Enhanced test for `delete_adapter` to handle multiple adapters,
|
||||
edge cases, and proper error handling.
|
||||
"""
|
||||
from peft import LoraConfig
|
||||
|
||||
for model_id in self.transformers_test_model_ids:
|
||||
for transformers_class in self.transformers_test_model_classes:
|
||||
model = transformers_class.from_pretrained(model_id).to(torch_device)
|
||||
|
||||
# Add multiple adapters
|
||||
peft_config_1 = LoraConfig(init_lora_weights=False)
|
||||
peft_config_2 = LoraConfig(init_lora_weights=False)
|
||||
model.add_adapter(peft_config_1, adapter_name="adapter_1")
|
||||
model.add_adapter(peft_config_2, adapter_name="adapter_2")
|
||||
|
||||
# Ensure adapters were added
|
||||
self.assertIn("adapter_1", model.peft_config)
|
||||
self.assertIn("adapter_2", model.peft_config)
|
||||
|
||||
# Delete a single adapter
|
||||
model.delete_adapter("adapter_1")
|
||||
self.assertNotIn("adapter_1", model.peft_config)
|
||||
self.assertIn("adapter_2", model.peft_config)
|
||||
|
||||
# Delete remaining adapter
|
||||
model.delete_adapter("adapter_2")
|
||||
self.assertNotIn("adapter_2", model.peft_config)
|
||||
self.assertFalse(model._hf_peft_config_loaded)
|
||||
|
||||
# Re-add adapters for edge case tests
|
||||
model.add_adapter(peft_config_1, adapter_name="adapter_1")
|
||||
model.add_adapter(peft_config_2, adapter_name="adapter_2")
|
||||
|
||||
# Attempt to delete multiple adapters at once
|
||||
model.delete_adapter(["adapter_1", "adapter_2"])
|
||||
self.assertNotIn("adapter_1", model.peft_config)
|
||||
self.assertNotIn("adapter_2", model.peft_config)
|
||||
self.assertFalse(model._hf_peft_config_loaded)
|
||||
|
||||
# Test edge cases
|
||||
with self.assertRaisesRegex(ValueError, "The following adapter\\(s\\) are not present"):
|
||||
model.delete_adapter("nonexistent_adapter")
|
||||
|
||||
with self.assertRaisesRegex(ValueError, "The following adapter\\(s\\) are not present"):
|
||||
model.delete_adapter(["adapter_1", "nonexistent_adapter"])
|
||||
|
||||
# Deleting with an empty list or None should not raise errors
|
||||
model.add_adapter(peft_config_1, adapter_name="adapter_1")
|
||||
model.add_adapter(peft_config_2, adapter_name="adapter_2")
|
||||
model.delete_adapter([]) # No-op
|
||||
self.assertIn("adapter_1", model.peft_config)
|
||||
self.assertIn("adapter_2", model.peft_config)
|
||||
|
||||
model.delete_adapter(None) # No-op
|
||||
self.assertIn("adapter_1", model.peft_config)
|
||||
self.assertIn("adapter_2", model.peft_config)
|
||||
|
||||
# Deleting duplicate adapter names in the list
|
||||
model.delete_adapter(["adapter_1", "adapter_1"])
|
||||
self.assertNotIn("adapter_1", model.peft_config)
|
||||
self.assertIn("adapter_2", model.peft_config)
|
||||
|
||||
@require_torch_gpu
|
||||
@require_bitsandbytes
|
||||
def test_peft_from_pretrained_kwargs(self):
|
||||
|
||||
Reference in New Issue
Block a user