[PEFT] Set eval mode when loading PEFT adapter (#34509)

* [PEFT] Set eval mode when loading PEFT adapter

Resolves #34469

When calling model.load_adapter to load a PEFT adapter, by default the
adapter should be set to eval mode. This is now correctly done. Users
can still pass is_trainable=True to load the adapter in training mode.

* Linter
This commit is contained in:
Benjamin Bossan
2024-11-28 13:56:25 +01:00
committed by GitHub
parent 5523e38b55
commit f4b674f269
2 changed files with 51 additions and 0 deletions

View File

@@ -81,6 +81,7 @@ class PeftAdapterMixin:
peft_config: Dict[str, Any] = None, peft_config: Dict[str, Any] = None,
adapter_state_dict: Optional[Dict[str, "torch.Tensor"]] = None, adapter_state_dict: Optional[Dict[str, "torch.Tensor"]] = None,
low_cpu_mem_usage: bool = False, low_cpu_mem_usage: bool = False,
is_trainable: bool = False,
adapter_kwargs: Optional[Dict[str, Any]] = None, adapter_kwargs: Optional[Dict[str, Any]] = None,
) -> None: ) -> None:
""" """
@@ -136,6 +137,9 @@ class PeftAdapterMixin:
low_cpu_mem_usage (`bool`, *optional*, defaults to `False`): low_cpu_mem_usage (`bool`, *optional*, defaults to `False`):
Reduce memory usage while loading the PEFT adapter. This should also speed up the loading process. Reduce memory usage while loading the PEFT adapter. This should also speed up the loading process.
Requires PEFT version 0.13.0 or higher. Requires PEFT version 0.13.0 or higher.
is_trainable (`bool`, *optional*, defaults to `False`):
Whether the adapter should be trainable or not. If `False`, the adapter will be frozen and can only be
used for inference.
adapter_kwargs (`Dict[str, Any]`, *optional*): adapter_kwargs (`Dict[str, Any]`, *optional*):
Additional keyword arguments passed along to the `from_pretrained` method of the adapter config and Additional keyword arguments passed along to the `from_pretrained` method of the adapter config and
`find_adapter_config_file` method. `find_adapter_config_file` method.
@@ -209,6 +213,7 @@ class PeftAdapterMixin:
token=token, token=token,
**adapter_kwargs, **adapter_kwargs,
) )
peft_config.inference_mode = not is_trainable
# Create and add fresh new adapters into the model. # Create and add fresh new adapters into the model.
inject_adapter_in_model(peft_config, self, adapter_name, **peft_load_kwargs) inject_adapter_in_model(peft_config, self, adapter_name, **peft_load_kwargs)
@@ -258,6 +263,9 @@ class PeftAdapterMixin:
if err_msg: if err_msg:
logger.warning(err_msg) logger.warning(err_msg)
if peft_config.inference_mode:
self.eval()
# Re-dispatch model and hooks in case the model is offloaded to CPU / Disk. # Re-dispatch model and hooks in case the model is offloaded to CPU / Disk.
if ( if (
(getattr(self, "hf_device_map", None) is not None) (getattr(self, "hf_device_map", None) is not None)

View File

@@ -622,3 +622,46 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin):
msg = f"Loading adapter weights from state_dict led to missing keys in the model: {key}" msg = f"Loading adapter weights from state_dict led to missing keys in the model: {key}"
self.assertIn(msg, cl.out) self.assertIn(msg, cl.out)
def test_peft_load_adapter_training_inference_mode_true(self):
"""
By default, when loading an adapter, the whole model should be in eval mode and no parameter should have
requires_grad=False.
"""
for model_id in self.peft_test_model_ids:
for transformers_class in self.transformers_test_model_classes:
peft_model = transformers_class.from_pretrained(model_id).to(torch_device)
with tempfile.TemporaryDirectory() as tmpdirname:
peft_model.save_pretrained(tmpdirname)
model = transformers_class.from_pretrained(peft_model.config._name_or_path)
model.load_adapter(tmpdirname)
assert not any(p.requires_grad for p in model.parameters())
assert not any(m.training for m in model.modules())
del model
def test_peft_load_adapter_training_inference_mode_false(self):
"""
When passing is_trainable=True, the LoRA modules should be in training mode and their parameters should have
requires_grad=True.
"""
for model_id in self.peft_test_model_ids:
for transformers_class in self.transformers_test_model_classes:
peft_model = transformers_class.from_pretrained(model_id).to(torch_device)
with tempfile.TemporaryDirectory() as tmpdirname:
peft_model.save_pretrained(tmpdirname)
model = transformers_class.from_pretrained(peft_model.config._name_or_path)
model.load_adapter(tmpdirname, is_trainable=True)
for name, module in model.named_modules():
if len(list(module.children())):
# only check leaf modules
continue
if "lora_" in name:
assert module.training
assert all(p.requires_grad for p in module.parameters())
else:
assert not module.training
assert all(not p.requires_grad for p in module.parameters())