[PEFT] Set eval mode when loading PEFT adapter (#34509)

* [PEFT] Set eval mode when loading PEFT adapter

Resolves #34469

When calling model.load_adapter to load a PEFT adapter, by default the
adapter should be set to eval mode. This is now correctly done. Users
can still pass is_trainable=True to load the adapter in training mode.

* Linter
This commit is contained in:
Benjamin Bossan
2024-11-28 13:56:25 +01:00
committed by GitHub
parent 5523e38b55
commit f4b674f269
2 changed files with 51 additions and 0 deletions

View File

@@ -81,6 +81,7 @@ class PeftAdapterMixin:
peft_config: Dict[str, Any] = None,
adapter_state_dict: Optional[Dict[str, "torch.Tensor"]] = None,
low_cpu_mem_usage: bool = False,
is_trainable: bool = False,
adapter_kwargs: Optional[Dict[str, Any]] = None,
) -> None:
"""
@@ -136,6 +137,9 @@ class PeftAdapterMixin:
low_cpu_mem_usage (`bool`, *optional*, defaults to `False`):
Reduce memory usage while loading the PEFT adapter. This should also speed up the loading process.
Requires PEFT version 0.13.0 or higher.
is_trainable (`bool`, *optional*, defaults to `False`):
Whether the adapter should be trainable or not. If `False`, the adapter will be frozen and can only be
used for inference.
adapter_kwargs (`Dict[str, Any]`, *optional*):
Additional keyword arguments passed along to the `from_pretrained` method of the adapter config and
`find_adapter_config_file` method.
@@ -209,6 +213,7 @@ class PeftAdapterMixin:
token=token,
**adapter_kwargs,
)
peft_config.inference_mode = not is_trainable
# Create and add fresh new adapters into the model.
inject_adapter_in_model(peft_config, self, adapter_name, **peft_load_kwargs)
@@ -258,6 +263,9 @@ class PeftAdapterMixin:
if err_msg:
logger.warning(err_msg)
if peft_config.inference_mode:
self.eval()
# Re-dispatch model and hooks in case the model is offloaded to CPU / Disk.
if (
(getattr(self, "hf_device_map", None) is not None)

View File

@@ -622,3 +622,46 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin):
msg = f"Loading adapter weights from state_dict led to missing keys in the model: {key}"
self.assertIn(msg, cl.out)
def test_peft_load_adapter_training_inference_mode_true(self):
"""
By default, when loading an adapter, the whole model should be in eval mode and no parameter should have
requires_grad=False.
"""
for model_id in self.peft_test_model_ids:
for transformers_class in self.transformers_test_model_classes:
peft_model = transformers_class.from_pretrained(model_id).to(torch_device)
with tempfile.TemporaryDirectory() as tmpdirname:
peft_model.save_pretrained(tmpdirname)
model = transformers_class.from_pretrained(peft_model.config._name_or_path)
model.load_adapter(tmpdirname)
assert not any(p.requires_grad for p in model.parameters())
assert not any(m.training for m in model.modules())
del model
def test_peft_load_adapter_training_inference_mode_false(self):
"""
When passing is_trainable=True, the LoRA modules should be in training mode and their parameters should have
requires_grad=True.
"""
for model_id in self.peft_test_model_ids:
for transformers_class in self.transformers_test_model_classes:
peft_model = transformers_class.from_pretrained(model_id).to(torch_device)
with tempfile.TemporaryDirectory() as tmpdirname:
peft_model.save_pretrained(tmpdirname)
model = transformers_class.from_pretrained(peft_model.config._name_or_path)
model.load_adapter(tmpdirname, is_trainable=True)
for name, module in model.named_modules():
if len(list(module.children())):
# only check leaf modules
continue
if "lora_" in name:
assert module.training
assert all(p.requires_grad for p in module.parameters())
else:
assert not module.training
assert all(not p.requires_grad for p in module.parameters())