[PEFT] Allow PEFT model dict to be loaded (#25721)
* Allow PEFT model dict to be loaded * make style * make style * Apply suggestions from code review * address comments * fixup * final change * added tests * fix test * better logic for handling if adapter has been loaded * Update tests/peft_integration/test_peft_integration.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --------- Co-authored-by: younesbelkada <younesbelkada@gmail.com> Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
8b13471494
commit
0a55d9f737
@@ -16,6 +16,8 @@ import os
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
from transformers import AutoModelForCausalLM, OPTForCausalLM
|
||||
from transformers.testing_utils import require_peft, require_torch, require_torch_gpu, slow, torch_device
|
||||
from transformers.utils import is_torch_available
|
||||
@@ -300,3 +302,33 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin):
|
||||
for model_id in self.peft_test_model_ids:
|
||||
pipe = pipeline("text-generation", model_id)
|
||||
_ = pipe("Hello")
|
||||
|
||||
def test_peft_add_adapter_with_state_dict(self):
|
||||
"""
|
||||
Simple test that tests the basic usage of PEFT model through `from_pretrained`. This test tests if
|
||||
add_adapter works as expected with a state_dict being passed.
|
||||
"""
|
||||
from peft import LoraConfig
|
||||
|
||||
dummy_input = torch.LongTensor([[0, 1, 2, 3, 4, 5, 6, 7]]).to(torch_device)
|
||||
|
||||
for model_id, peft_model_id in zip(self.transformers_test_model_ids, self.peft_test_model_ids):
|
||||
for transformers_class in self.transformers_test_model_classes:
|
||||
model = transformers_class.from_pretrained(model_id).to(torch_device)
|
||||
|
||||
peft_config = LoraConfig(init_lora_weights=False)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
model.load_adapter(peft_model_id=None)
|
||||
|
||||
state_dict_path = hf_hub_download(peft_model_id, "adapter_model.bin")
|
||||
|
||||
dummy_state_dict = torch.load(state_dict_path)
|
||||
|
||||
model.load_adapter(adapter_state_dict=dummy_state_dict, peft_config=peft_config)
|
||||
with self.assertRaises(ValueError):
|
||||
model.load_adapter(model.load_adapter(adapter_state_dict=dummy_state_dict, peft_config=None))
|
||||
self.assertTrue(self._check_lora_correctly_converted(model))
|
||||
|
||||
# dummy generation
|
||||
_ = model.generate(input_ids=dummy_input)
|
||||
|
||||
Reference in New Issue
Block a user