[PEFT] Add warning for missing key in LoRA adapter (#34068)
When loading a LoRA adapter, so far, there was only a warning when there were unexpected keys in the checkpoint. Now, there is also a warning when there are missing keys. This change is consistent with https://github.com/huggingface/peft/pull/2118 in PEFT and the planned PR https://github.com/huggingface/diffusers/pull/9622 in diffusers. Apart from this change, the error message for unexpected keys was slightly altered for consistency (it should be more readable now). Also, besides adding a test for the missing keys warning, a test for unexpected keys warning was also added, as it was missing so far.
This commit is contained in:
@@ -235,13 +235,29 @@ class PeftAdapterMixin:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if incompatible_keys is not None:
|
if incompatible_keys is not None:
|
||||||
# check only for unexpected keys
|
err_msg = ""
|
||||||
|
origin_name = peft_model_id if peft_model_id is not None else "state_dict"
|
||||||
|
# Check for unexpected keys.
|
||||||
if hasattr(incompatible_keys, "unexpected_keys") and len(incompatible_keys.unexpected_keys) > 0:
|
if hasattr(incompatible_keys, "unexpected_keys") and len(incompatible_keys.unexpected_keys) > 0:
|
||||||
logger.warning(
|
err_msg = (
|
||||||
f"Loading adapter weights from {peft_model_id} led to unexpected keys not found in the model: "
|
f"Loading adapter weights from {origin_name} led to unexpected keys not found in the model: "
|
||||||
f" {incompatible_keys.unexpected_keys}. "
|
f"{', '.join(incompatible_keys.unexpected_keys)}. "
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Check for missing keys.
|
||||||
|
missing_keys = getattr(incompatible_keys, "missing_keys", None)
|
||||||
|
if missing_keys:
|
||||||
|
# Filter missing keys specific to the current adapter, as missing base model keys are expected.
|
||||||
|
lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
|
||||||
|
if lora_missing_keys:
|
||||||
|
err_msg += (
|
||||||
|
f"Loading adapter weights from {origin_name} led to missing keys in the model: "
|
||||||
|
f"{', '.join(lora_missing_keys)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if err_msg:
|
||||||
|
logger.warning(err_msg)
|
||||||
|
|
||||||
# Re-dispatch model and hooks in case the model is offloaded to CPU / Disk.
|
# Re-dispatch model and hooks in case the model is offloaded to CPU / Disk.
|
||||||
if (
|
if (
|
||||||
(getattr(self, "hf_device_map", None) is not None)
|
(getattr(self, "hf_device_map", None) is not None)
|
||||||
|
|||||||
@@ -20,8 +20,9 @@ import unittest
|
|||||||
from huggingface_hub import hf_hub_download
|
from huggingface_hub import hf_hub_download
|
||||||
from packaging import version
|
from packaging import version
|
||||||
|
|
||||||
from transformers import AutoModelForCausalLM, OPTForCausalLM
|
from transformers import AutoModelForCausalLM, OPTForCausalLM, logging
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
|
CaptureLogger,
|
||||||
require_bitsandbytes,
|
require_bitsandbytes,
|
||||||
require_peft,
|
require_peft,
|
||||||
require_torch,
|
require_torch,
|
||||||
@@ -72,9 +73,15 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin):
|
|||||||
This checks if we pass a remote folder that contains an adapter config and adapter weights, it
|
This checks if we pass a remote folder that contains an adapter config and adapter weights, it
|
||||||
should correctly load a model that has adapters injected on it.
|
should correctly load a model that has adapters injected on it.
|
||||||
"""
|
"""
|
||||||
|
logger = logging.get_logger("transformers.integrations.peft")
|
||||||
|
|
||||||
for model_id in self.peft_test_model_ids:
|
for model_id in self.peft_test_model_ids:
|
||||||
for transformers_class in self.transformers_test_model_classes:
|
for transformers_class in self.transformers_test_model_classes:
|
||||||
|
with CaptureLogger(logger) as cl:
|
||||||
peft_model = transformers_class.from_pretrained(model_id).to(torch_device)
|
peft_model = transformers_class.from_pretrained(model_id).to(torch_device)
|
||||||
|
# ensure that under normal circumstances, there are no warnings about keys
|
||||||
|
self.assertNotIn("unexpected keys", cl.out)
|
||||||
|
self.assertNotIn("missing keys", cl.out)
|
||||||
|
|
||||||
self.assertTrue(self._check_lora_correctly_converted(peft_model))
|
self.assertTrue(self._check_lora_correctly_converted(peft_model))
|
||||||
self.assertTrue(peft_model._hf_peft_config_loaded)
|
self.assertTrue(peft_model._hf_peft_config_loaded)
|
||||||
@@ -548,3 +555,70 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin):
|
|||||||
|
|
||||||
model = OPTForCausalLM.from_pretrained(peft_model_id, adapter_kwargs=adapter_kwargs)
|
model = OPTForCausalLM.from_pretrained(peft_model_id, adapter_kwargs=adapter_kwargs)
|
||||||
self.assertTrue(self._check_lora_correctly_converted(model))
|
self.assertTrue(self._check_lora_correctly_converted(model))
|
||||||
|
|
||||||
|
def test_peft_from_pretrained_unexpected_keys_warning(self):
|
||||||
|
"""
|
||||||
|
Test for warning when loading a PEFT checkpoint with unexpected keys.
|
||||||
|
"""
|
||||||
|
from peft import LoraConfig
|
||||||
|
|
||||||
|
logger = logging.get_logger("transformers.integrations.peft")
|
||||||
|
|
||||||
|
for model_id, peft_model_id in zip(self.transformers_test_model_ids, self.peft_test_model_ids):
|
||||||
|
for transformers_class in self.transformers_test_model_classes:
|
||||||
|
model = transformers_class.from_pretrained(model_id).to(torch_device)
|
||||||
|
|
||||||
|
peft_config = LoraConfig()
|
||||||
|
state_dict_path = hf_hub_download(peft_model_id, "adapter_model.bin")
|
||||||
|
dummy_state_dict = torch.load(state_dict_path)
|
||||||
|
|
||||||
|
# add unexpected key
|
||||||
|
dummy_state_dict["foobar"] = next(iter(dummy_state_dict.values()))
|
||||||
|
|
||||||
|
with CaptureLogger(logger) as cl:
|
||||||
|
model.load_adapter(
|
||||||
|
adapter_state_dict=dummy_state_dict, peft_config=peft_config, low_cpu_mem_usage=False
|
||||||
|
)
|
||||||
|
|
||||||
|
msg = "Loading adapter weights from state_dict led to unexpected keys not found in the model: foobar"
|
||||||
|
self.assertIn(msg, cl.out)
|
||||||
|
|
||||||
|
def test_peft_from_pretrained_missing_keys_warning(self):
|
||||||
|
"""
|
||||||
|
Test for warning when loading a PEFT checkpoint with missing keys.
|
||||||
|
"""
|
||||||
|
from peft import LoraConfig
|
||||||
|
|
||||||
|
logger = logging.get_logger("transformers.integrations.peft")
|
||||||
|
|
||||||
|
for model_id, peft_model_id in zip(self.transformers_test_model_ids, self.peft_test_model_ids):
|
||||||
|
for transformers_class in self.transformers_test_model_classes:
|
||||||
|
model = transformers_class.from_pretrained(model_id).to(torch_device)
|
||||||
|
|
||||||
|
peft_config = LoraConfig()
|
||||||
|
state_dict_path = hf_hub_download(peft_model_id, "adapter_model.bin")
|
||||||
|
dummy_state_dict = torch.load(state_dict_path)
|
||||||
|
|
||||||
|
# remove a key so that we have missing keys
|
||||||
|
key = next(iter(dummy_state_dict.keys()))
|
||||||
|
del dummy_state_dict[key]
|
||||||
|
|
||||||
|
with CaptureLogger(logger) as cl:
|
||||||
|
model.load_adapter(
|
||||||
|
adapter_state_dict=dummy_state_dict,
|
||||||
|
peft_config=peft_config,
|
||||||
|
low_cpu_mem_usage=False,
|
||||||
|
adapter_name="other",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Here we need to adjust the key name a bit to account for PEFT-specific naming.
|
||||||
|
# 1. Remove PEFT-specific prefix
|
||||||
|
# If merged after dropping Python 3.8, we can use: key = key.removeprefix(peft_prefix)
|
||||||
|
peft_prefix = "base_model.model."
|
||||||
|
key = key[len(peft_prefix) :]
|
||||||
|
# 2. Insert adapter name
|
||||||
|
prefix, _, suffix = key.rpartition(".")
|
||||||
|
key = f"{prefix}.other.{suffix}"
|
||||||
|
|
||||||
|
msg = f"Loading adapter weights from state_dict led to missing keys in the model: {key}"
|
||||||
|
self.assertIn(msg, cl.out)
|
||||||
|
|||||||
Reference in New Issue
Block a user