From f4b674f2690b90f6c2e278fb8f612a413a68934b Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Thu, 28 Nov 2024 13:56:25 +0100 Subject: [PATCH] [PEFT] Set eval mode when loading PEFT adapter (#34509) * [PEFT] Set eval mode when loading PEFT adapter Resolves #34469 When calling model.load_adapter to load a PEFT adapter, by default the adapter should be set to eval mode. This is now correctly done. Users can still pass is_trainable=True to load the adapter in training mode. * Linter --- src/transformers/integrations/peft.py | 8 ++++ .../peft_integration/test_peft_integration.py | 43 +++++++++++++++++++ 2 files changed, 51 insertions(+) diff --git a/src/transformers/integrations/peft.py b/src/transformers/integrations/peft.py index b3352be0f9..ef09281431 100644 --- a/src/transformers/integrations/peft.py +++ b/src/transformers/integrations/peft.py @@ -81,6 +81,7 @@ class PeftAdapterMixin: peft_config: Dict[str, Any] = None, adapter_state_dict: Optional[Dict[str, "torch.Tensor"]] = None, low_cpu_mem_usage: bool = False, + is_trainable: bool = False, adapter_kwargs: Optional[Dict[str, Any]] = None, ) -> None: """ @@ -136,6 +137,9 @@ class PeftAdapterMixin: low_cpu_mem_usage (`bool`, *optional*, defaults to `False`): Reduce memory usage while loading the PEFT adapter. This should also speed up the loading process. Requires PEFT version 0.13.0 or higher. + is_trainable (`bool`, *optional*, defaults to `False`): + Whether the adapter should be trainable or not. If `False`, the adapter will be frozen and can only be + used for inference. adapter_kwargs (`Dict[str, Any]`, *optional*): Additional keyword arguments passed along to the `from_pretrained` method of the adapter config and `find_adapter_config_file` method. @@ -209,6 +213,7 @@ class PeftAdapterMixin: token=token, **adapter_kwargs, ) + peft_config.inference_mode = not is_trainable # Create and add fresh new adapters into the model. inject_adapter_in_model(peft_config, self, adapter_name, **peft_load_kwargs) @@ -258,6 +263,9 @@ class PeftAdapterMixin: if err_msg: logger.warning(err_msg) + if peft_config.inference_mode: + self.eval() + # Re-dispatch model and hooks in case the model is offloaded to CPU / Disk. if ( (getattr(self, "hf_device_map", None) is not None) diff --git a/tests/peft_integration/test_peft_integration.py b/tests/peft_integration/test_peft_integration.py index aebf2b2952..48fd6da3d6 100644 --- a/tests/peft_integration/test_peft_integration.py +++ b/tests/peft_integration/test_peft_integration.py @@ -622,3 +622,46 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin): msg = f"Loading adapter weights from state_dict led to missing keys in the model: {key}" self.assertIn(msg, cl.out) + + def test_peft_load_adapter_training_inference_mode_true(self): + """ + By default, when loading an adapter, the whole model should be in eval mode and no parameter should have + requires_grad=False. + """ + for model_id in self.peft_test_model_ids: + for transformers_class in self.transformers_test_model_classes: + peft_model = transformers_class.from_pretrained(model_id).to(torch_device) + + with tempfile.TemporaryDirectory() as tmpdirname: + peft_model.save_pretrained(tmpdirname) + model = transformers_class.from_pretrained(peft_model.config._name_or_path) + model.load_adapter(tmpdirname) + assert not any(p.requires_grad for p in model.parameters()) + assert not any(m.training for m in model.modules()) + del model + + def test_peft_load_adapter_training_inference_mode_false(self): + """ + When passing is_trainable=True, the LoRA modules should be in training mode and their parameters should have + requires_grad=True. + """ + for model_id in self.peft_test_model_ids: + for transformers_class in self.transformers_test_model_classes: + peft_model = transformers_class.from_pretrained(model_id).to(torch_device) + + with tempfile.TemporaryDirectory() as tmpdirname: + peft_model.save_pretrained(tmpdirname) + model = transformers_class.from_pretrained(peft_model.config._name_or_path) + model.load_adapter(tmpdirname, is_trainable=True) + + for name, module in model.named_modules(): + if len(list(module.children())): + # only check leaf modules + continue + + if "lora_" in name: + assert module.training + assert all(p.requires_grad for p in module.parameters()) + else: + assert not module.training + assert all(not p.requires_grad for p in module.parameters())