[PEFT] Set eval mode when loading PEFT adapter (#34509)
* [PEFT] Set eval mode when loading PEFT adapter Resolves #34469 When calling model.load_adapter to load a PEFT adapter, by default the adapter should be set to eval mode. This is now correctly done. Users can still pass is_trainable=True to load the adapter in training mode. * Linter
This commit is contained in:
@@ -622,3 +622,46 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin):
|
||||
|
||||
msg = f"Loading adapter weights from state_dict led to missing keys in the model: {key}"
|
||||
self.assertIn(msg, cl.out)
|
||||
|
||||
def test_peft_load_adapter_training_inference_mode_true(self):
|
||||
"""
|
||||
By default, when loading an adapter, the whole model should be in eval mode and no parameter should have
|
||||
requires_grad=False.
|
||||
"""
|
||||
for model_id in self.peft_test_model_ids:
|
||||
for transformers_class in self.transformers_test_model_classes:
|
||||
peft_model = transformers_class.from_pretrained(model_id).to(torch_device)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
peft_model.save_pretrained(tmpdirname)
|
||||
model = transformers_class.from_pretrained(peft_model.config._name_or_path)
|
||||
model.load_adapter(tmpdirname)
|
||||
assert not any(p.requires_grad for p in model.parameters())
|
||||
assert not any(m.training for m in model.modules())
|
||||
del model
|
||||
|
||||
def test_peft_load_adapter_training_inference_mode_false(self):
|
||||
"""
|
||||
When passing is_trainable=True, the LoRA modules should be in training mode and their parameters should have
|
||||
requires_grad=True.
|
||||
"""
|
||||
for model_id in self.peft_test_model_ids:
|
||||
for transformers_class in self.transformers_test_model_classes:
|
||||
peft_model = transformers_class.from_pretrained(model_id).to(torch_device)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
peft_model.save_pretrained(tmpdirname)
|
||||
model = transformers_class.from_pretrained(peft_model.config._name_or_path)
|
||||
model.load_adapter(tmpdirname, is_trainable=True)
|
||||
|
||||
for name, module in model.named_modules():
|
||||
if len(list(module.children())):
|
||||
# only check leaf modules
|
||||
continue
|
||||
|
||||
if "lora_" in name:
|
||||
assert module.training
|
||||
assert all(p.requires_grad for p in module.parameters())
|
||||
else:
|
||||
assert not module.training
|
||||
assert all(not p.requires_grad for p in module.parameters())
|
||||
|
||||
Reference in New Issue
Block a user