added logic for deleting adapters once loaded (#34650)
* added logic for deleting adapters once loaded * updated to the latest version of transformers, merged utility function into the source * updated with missing check * added peft version check * Apply suggestions from code review Co-authored-by: Anton Vlasjuk <73884904+vasqu@users.noreply.github.com> * changes according to reviewer * added test for deleting adapter(s) * styling changes * styling changes in test * removed redundant code * formatted my contributions with ruff * optimized error handling * ruff formatted with correct config * resolved formatting issues --------- Co-authored-by: Anton Vlasjuk <73884904+vasqu@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
1650e0e514
commit
ca00950057
@@ -11,6 +11,7 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import importlib
|
import importlib
|
||||||
import inspect
|
import inspect
|
||||||
import warnings
|
import warnings
|
||||||
@@ -525,3 +526,64 @@ class PeftAdapterMixin:
|
|||||||
offload_dir=offload_folder,
|
offload_dir=offload_folder,
|
||||||
**dispatch_model_kwargs,
|
**dispatch_model_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def delete_adapter(self, adapter_names: Union[List[str], str]) -> None:
|
||||||
|
"""
|
||||||
|
Delete an adapter's LoRA layers from the underlying model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
adapter_names (`Union[List[str], str]`):
|
||||||
|
The name(s) of the adapter(s) to delete.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```py
|
||||||
|
from diffusers import AutoPipelineForText2Image
|
||||||
|
import torch
|
||||||
|
|
||||||
|
pipeline = AutoPipelineForText2Image.from_pretrained(
|
||||||
|
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
|
||||||
|
).to("cuda")
|
||||||
|
pipeline.load_lora_weights(
|
||||||
|
"jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_names="cinematic"
|
||||||
|
)
|
||||||
|
pipeline.delete_adapters("cinematic")
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
check_peft_version(min_version=MIN_PEFT_VERSION)
|
||||||
|
|
||||||
|
if not self._hf_peft_config_loaded:
|
||||||
|
raise ValueError("No adapter loaded. Please load an adapter first.")
|
||||||
|
|
||||||
|
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||||
|
|
||||||
|
if isinstance(adapter_names, str):
|
||||||
|
adapter_names = [adapter_names]
|
||||||
|
|
||||||
|
# Check that all adapter names are present in the config
|
||||||
|
missing_adapters = [name for name in adapter_names if name not in self.peft_config]
|
||||||
|
if missing_adapters:
|
||||||
|
raise ValueError(
|
||||||
|
f"The following adapter(s) are not present and cannot be deleted: {', '.join(missing_adapters)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
for adapter_name in adapter_names:
|
||||||
|
for module in self.modules():
|
||||||
|
if isinstance(module, BaseTunerLayer):
|
||||||
|
if hasattr(module, "delete_adapter"):
|
||||||
|
module.delete_adapter(adapter_name)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"The version of PEFT you are using is not compatible, please use a version that is greater than 0.6.1"
|
||||||
|
)
|
||||||
|
|
||||||
|
# For transformers integration - we need to pop the adapter from the config
|
||||||
|
if getattr(self, "_hf_peft_config_loaded", False) and hasattr(self, "peft_config"):
|
||||||
|
self.peft_config.pop(adapter_name, None)
|
||||||
|
|
||||||
|
# In case all adapters are deleted, we need to delete the config
|
||||||
|
# and make sure to set the flag to False
|
||||||
|
if len(self.peft_config) == 0:
|
||||||
|
del self.peft_config
|
||||||
|
self._hf_peft_config_loaded = False
|
||||||
|
|||||||
@@ -350,7 +350,6 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin):
|
|||||||
self.assertFalse(
|
self.assertFalse(
|
||||||
torch.allclose(logits_adapter_1.logits, logits_adapter_mixed.logits, atol=1e-6, rtol=1e-6)
|
torch.allclose(logits_adapter_1.logits, logits_adapter_mixed.logits, atol=1e-6, rtol=1e-6)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertFalse(
|
self.assertFalse(
|
||||||
torch.allclose(logits_adapter_2.logits, logits_adapter_mixed.logits, atol=1e-6, rtol=1e-6)
|
torch.allclose(logits_adapter_2.logits, logits_adapter_mixed.logits, atol=1e-6, rtol=1e-6)
|
||||||
)
|
)
|
||||||
@@ -359,6 +358,70 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin):
|
|||||||
with self.assertRaises(ValueError), tempfile.TemporaryDirectory() as tmpdirname:
|
with self.assertRaises(ValueError), tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
model.save_pretrained(tmpdirname)
|
model.save_pretrained(tmpdirname)
|
||||||
|
|
||||||
|
def test_delete_adapter(self):
|
||||||
|
"""
|
||||||
|
Enhanced test for `delete_adapter` to handle multiple adapters,
|
||||||
|
edge cases, and proper error handling.
|
||||||
|
"""
|
||||||
|
from peft import LoraConfig
|
||||||
|
|
||||||
|
for model_id in self.transformers_test_model_ids:
|
||||||
|
for transformers_class in self.transformers_test_model_classes:
|
||||||
|
model = transformers_class.from_pretrained(model_id).to(torch_device)
|
||||||
|
|
||||||
|
# Add multiple adapters
|
||||||
|
peft_config_1 = LoraConfig(init_lora_weights=False)
|
||||||
|
peft_config_2 = LoraConfig(init_lora_weights=False)
|
||||||
|
model.add_adapter(peft_config_1, adapter_name="adapter_1")
|
||||||
|
model.add_adapter(peft_config_2, adapter_name="adapter_2")
|
||||||
|
|
||||||
|
# Ensure adapters were added
|
||||||
|
self.assertIn("adapter_1", model.peft_config)
|
||||||
|
self.assertIn("adapter_2", model.peft_config)
|
||||||
|
|
||||||
|
# Delete a single adapter
|
||||||
|
model.delete_adapter("adapter_1")
|
||||||
|
self.assertNotIn("adapter_1", model.peft_config)
|
||||||
|
self.assertIn("adapter_2", model.peft_config)
|
||||||
|
|
||||||
|
# Delete remaining adapter
|
||||||
|
model.delete_adapter("adapter_2")
|
||||||
|
self.assertNotIn("adapter_2", model.peft_config)
|
||||||
|
self.assertFalse(model._hf_peft_config_loaded)
|
||||||
|
|
||||||
|
# Re-add adapters for edge case tests
|
||||||
|
model.add_adapter(peft_config_1, adapter_name="adapter_1")
|
||||||
|
model.add_adapter(peft_config_2, adapter_name="adapter_2")
|
||||||
|
|
||||||
|
# Attempt to delete multiple adapters at once
|
||||||
|
model.delete_adapter(["adapter_1", "adapter_2"])
|
||||||
|
self.assertNotIn("adapter_1", model.peft_config)
|
||||||
|
self.assertNotIn("adapter_2", model.peft_config)
|
||||||
|
self.assertFalse(model._hf_peft_config_loaded)
|
||||||
|
|
||||||
|
# Test edge cases
|
||||||
|
with self.assertRaisesRegex(ValueError, "The following adapter\\(s\\) are not present"):
|
||||||
|
model.delete_adapter("nonexistent_adapter")
|
||||||
|
|
||||||
|
with self.assertRaisesRegex(ValueError, "The following adapter\\(s\\) are not present"):
|
||||||
|
model.delete_adapter(["adapter_1", "nonexistent_adapter"])
|
||||||
|
|
||||||
|
# Deleting with an empty list or None should not raise errors
|
||||||
|
model.add_adapter(peft_config_1, adapter_name="adapter_1")
|
||||||
|
model.add_adapter(peft_config_2, adapter_name="adapter_2")
|
||||||
|
model.delete_adapter([]) # No-op
|
||||||
|
self.assertIn("adapter_1", model.peft_config)
|
||||||
|
self.assertIn("adapter_2", model.peft_config)
|
||||||
|
|
||||||
|
model.delete_adapter(None) # No-op
|
||||||
|
self.assertIn("adapter_1", model.peft_config)
|
||||||
|
self.assertIn("adapter_2", model.peft_config)
|
||||||
|
|
||||||
|
# Deleting duplicate adapter names in the list
|
||||||
|
model.delete_adapter(["adapter_1", "adapter_1"])
|
||||||
|
self.assertNotIn("adapter_1", model.peft_config)
|
||||||
|
self.assertIn("adapter_2", model.peft_config)
|
||||||
|
|
||||||
@require_torch_gpu
|
@require_torch_gpu
|
||||||
@require_bitsandbytes
|
@require_bitsandbytes
|
||||||
def test_peft_from_pretrained_kwargs(self):
|
def test_peft_from_pretrained_kwargs(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user