[PEFT] Set eval mode when loading PEFT adapter (#34509)
* [PEFT] Set eval mode when loading PEFT adapter Resolves #34469 When calling model.load_adapter to load a PEFT adapter, by default the adapter should be set to eval mode. This is now correctly done. Users can still pass is_trainable=True to load the adapter in training mode. * Linter
This commit is contained in:
@@ -81,6 +81,7 @@ class PeftAdapterMixin:
|
||||
peft_config: Dict[str, Any] = None,
|
||||
adapter_state_dict: Optional[Dict[str, "torch.Tensor"]] = None,
|
||||
low_cpu_mem_usage: bool = False,
|
||||
is_trainable: bool = False,
|
||||
adapter_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
@@ -136,6 +137,9 @@ class PeftAdapterMixin:
|
||||
low_cpu_mem_usage (`bool`, *optional*, defaults to `False`):
|
||||
Reduce memory usage while loading the PEFT adapter. This should also speed up the loading process.
|
||||
Requires PEFT version 0.13.0 or higher.
|
||||
is_trainable (`bool`, *optional*, defaults to `False`):
|
||||
Whether the adapter should be trainable or not. If `False`, the adapter will be frozen and can only be
|
||||
used for inference.
|
||||
adapter_kwargs (`Dict[str, Any]`, *optional*):
|
||||
Additional keyword arguments passed along to the `from_pretrained` method of the adapter config and
|
||||
`find_adapter_config_file` method.
|
||||
@@ -209,6 +213,7 @@ class PeftAdapterMixin:
|
||||
token=token,
|
||||
**adapter_kwargs,
|
||||
)
|
||||
peft_config.inference_mode = not is_trainable
|
||||
|
||||
# Create and add fresh new adapters into the model.
|
||||
inject_adapter_in_model(peft_config, self, adapter_name, **peft_load_kwargs)
|
||||
@@ -258,6 +263,9 @@ class PeftAdapterMixin:
|
||||
if err_msg:
|
||||
logger.warning(err_msg)
|
||||
|
||||
if peft_config.inference_mode:
|
||||
self.eval()
|
||||
|
||||
# Re-dispatch model and hooks in case the model is offloaded to CPU / Disk.
|
||||
if (
|
||||
(getattr(self, "hf_device_map", None) is not None)
|
||||
|
||||
@@ -622,3 +622,46 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin):
|
||||
|
||||
msg = f"Loading adapter weights from state_dict led to missing keys in the model: {key}"
|
||||
self.assertIn(msg, cl.out)
|
||||
|
||||
def test_peft_load_adapter_training_inference_mode_true(self):
|
||||
"""
|
||||
By default, when loading an adapter, the whole model should be in eval mode and no parameter should have
|
||||
requires_grad=False.
|
||||
"""
|
||||
for model_id in self.peft_test_model_ids:
|
||||
for transformers_class in self.transformers_test_model_classes:
|
||||
peft_model = transformers_class.from_pretrained(model_id).to(torch_device)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
peft_model.save_pretrained(tmpdirname)
|
||||
model = transformers_class.from_pretrained(peft_model.config._name_or_path)
|
||||
model.load_adapter(tmpdirname)
|
||||
assert not any(p.requires_grad for p in model.parameters())
|
||||
assert not any(m.training for m in model.modules())
|
||||
del model
|
||||
|
||||
def test_peft_load_adapter_training_inference_mode_false(self):
|
||||
"""
|
||||
When passing is_trainable=True, the LoRA modules should be in training mode and their parameters should have
|
||||
requires_grad=True.
|
||||
"""
|
||||
for model_id in self.peft_test_model_ids:
|
||||
for transformers_class in self.transformers_test_model_classes:
|
||||
peft_model = transformers_class.from_pretrained(model_id).to(torch_device)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
peft_model.save_pretrained(tmpdirname)
|
||||
model = transformers_class.from_pretrained(peft_model.config._name_or_path)
|
||||
model.load_adapter(tmpdirname, is_trainable=True)
|
||||
|
||||
for name, module in model.named_modules():
|
||||
if len(list(module.children())):
|
||||
# only check leaf modules
|
||||
continue
|
||||
|
||||
if "lora_" in name:
|
||||
assert module.training
|
||||
assert all(p.requires_grad for p in module.parameters())
|
||||
else:
|
||||
assert not module.training
|
||||
assert all(not p.requires_grad for p in module.parameters())
|
||||
|
||||
Reference in New Issue
Block a user