[PEFT] Fix PEFT + gradient checkpointing (#25846)
* fix PEFT + gradient checkpointing * add disable RG * polish tests * fix comment * Revert "fix comment" This reverts commit b85386f50d2b104bac522e823c47b7e232116a47. * final explanations and tests
This commit is contained in:
@@ -1750,6 +1750,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
|
raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
|
||||||
self.apply(partial(self._set_gradient_checkpointing, value=True))
|
self.apply(partial(self._set_gradient_checkpointing, value=True))
|
||||||
|
|
||||||
|
if getattr(self, "_hf_peft_config_loaded", False):
|
||||||
|
# When using PEFT + gradient checkpointing + Trainer we need to make sure the input has requires_grad=True
|
||||||
|
# we do it also on PEFT: https://github.com/huggingface/peft/blob/85013987aa82aa1af3da1236b6902556ce3e483e/src/peft/peft_model.py#L334
|
||||||
|
# When training with PEFT, only LoRA layers will have requires grad set to True, but the output of frozen layers need to propagate
|
||||||
|
# the gradients to make sure the gradient flows.
|
||||||
|
self.enable_input_require_grads()
|
||||||
|
|
||||||
def gradient_checkpointing_disable(self):
|
def gradient_checkpointing_disable(self):
|
||||||
"""
|
"""
|
||||||
Deactivates gradient checkpointing for the current model.
|
Deactivates gradient checkpointing for the current model.
|
||||||
@@ -1760,6 +1767,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
if self.supports_gradient_checkpointing:
|
if self.supports_gradient_checkpointing:
|
||||||
self.apply(partial(self._set_gradient_checkpointing, value=False))
|
self.apply(partial(self._set_gradient_checkpointing, value=False))
|
||||||
|
|
||||||
|
if getattr(self, "_hf_peft_config_loaded", False):
|
||||||
|
self.disable_input_require_grads()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_gradient_checkpointing(self) -> bool:
|
def is_gradient_checkpointing(self) -> bool:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -179,6 +179,52 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin):
|
|||||||
model_from_pretrained = transformers_class.from_pretrained(tmpdirname).to(torch_device)
|
model_from_pretrained = transformers_class.from_pretrained(tmpdirname).to(torch_device)
|
||||||
self.assertTrue(self._check_lora_correctly_converted(model_from_pretrained))
|
self.assertTrue(self._check_lora_correctly_converted(model_from_pretrained))
|
||||||
|
|
||||||
|
def test_peft_add_adapter_training_gradient_checkpointing(self):
|
||||||
|
"""
|
||||||
|
Simple test that tests if `add_adapter` works as expected when training with
|
||||||
|
gradient checkpointing.
|
||||||
|
"""
|
||||||
|
from peft import LoraConfig
|
||||||
|
|
||||||
|
for model_id in self.transformers_test_model_ids:
|
||||||
|
for transformers_class in self.transformers_test_model_classes:
|
||||||
|
model = transformers_class.from_pretrained(model_id).to(torch_device)
|
||||||
|
|
||||||
|
peft_config = LoraConfig(init_lora_weights=False)
|
||||||
|
|
||||||
|
model.add_adapter(peft_config)
|
||||||
|
|
||||||
|
self.assertTrue(self._check_lora_correctly_converted(model))
|
||||||
|
|
||||||
|
# When attaching adapters the input embeddings will stay frozen, this will
|
||||||
|
# lead to the output embedding having requires_grad=False.
|
||||||
|
dummy_input = torch.LongTensor([[0, 1, 2, 3, 4, 5, 6, 7]]).to(torch_device)
|
||||||
|
frozen_output = model.get_input_embeddings()(dummy_input)
|
||||||
|
self.assertTrue(frozen_output.requires_grad is False)
|
||||||
|
|
||||||
|
model.gradient_checkpointing_enable()
|
||||||
|
|
||||||
|
# Since here we attached the hook, the input should have requires_grad to set
|
||||||
|
# properly
|
||||||
|
non_frozen_output = model.get_input_embeddings()(dummy_input)
|
||||||
|
self.assertTrue(non_frozen_output.requires_grad is True)
|
||||||
|
|
||||||
|
# To repro the Trainer issue
|
||||||
|
dummy_input.requires_grad = False
|
||||||
|
|
||||||
|
for name, param in model.named_parameters():
|
||||||
|
if "lora" in name.lower():
|
||||||
|
self.assertTrue(param.requires_grad)
|
||||||
|
|
||||||
|
logits = model(dummy_input).logits
|
||||||
|
loss = logits.mean()
|
||||||
|
loss.backward()
|
||||||
|
|
||||||
|
for name, param in model.named_parameters():
|
||||||
|
if param.requires_grad:
|
||||||
|
self.assertTrue("lora" in name.lower())
|
||||||
|
self.assertTrue(param.grad is not None)
|
||||||
|
|
||||||
def test_peft_add_multi_adapter(self):
|
def test_peft_add_multi_adapter(self):
|
||||||
"""
|
"""
|
||||||
Simple test that tests the basic usage of PEFT model through `from_pretrained`. This test tests if
|
Simple test that tests the basic usage of PEFT model through `from_pretrained`. This test tests if
|
||||||
|
|||||||
Reference in New Issue
Block a user