[PEFT] Final fixes (#26559)

* fix issues with PEFT

* logger warning futurewarning issues

* fixup

* adapt from suggestions

* oops

* rm test
This commit is contained in:
Younes Belkada
2023-10-03 14:53:09 +02:00
committed by GitHub
parent ae9a344cce
commit 2aef9a9601
4 changed files with 54 additions and 16 deletions

View File

@@ -312,6 +312,42 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin):
# dummy generation
_ = peft_model.generate(input_ids=torch.LongTensor([[0, 1, 2, 3, 4, 5, 6, 7]]).to(torch_device))
@require_torch_gpu
def test_peft_save_quantized(self):
"""
Simple test that tests the basic usage of PEFT model save_pretrained with quantized base models
"""
# 4bit
for model_id in self.peft_test_model_ids:
for transformers_class in self.transformers_test_model_classes:
peft_model = transformers_class.from_pretrained(model_id, load_in_4bit=True, device_map="auto")
module = peft_model.model.decoder.layers[0].self_attn.v_proj
self.assertTrue(module.__class__.__name__ == "Linear4bit")
self.assertTrue(peft_model.hf_device_map is not None)
with tempfile.TemporaryDirectory() as tmpdirname:
peft_model.save_pretrained(tmpdirname)
self.assertTrue("adapter_model.bin" in os.listdir(tmpdirname))
self.assertTrue("adapter_config.json" in os.listdir(tmpdirname))
self.assertTrue("pytorch_model.bin" not in os.listdir(tmpdirname))
# 8-bit
for model_id in self.peft_test_model_ids:
for transformers_class in self.transformers_test_model_classes:
peft_model = transformers_class.from_pretrained(model_id, load_in_8bit=True, device_map="auto")
module = peft_model.model.decoder.layers[0].self_attn.v_proj
self.assertTrue(module.__class__.__name__ == "Linear8bitLt")
self.assertTrue(peft_model.hf_device_map is not None)
with tempfile.TemporaryDirectory() as tmpdirname:
peft_model.save_pretrained(tmpdirname)
self.assertTrue("adapter_model.bin" in os.listdir(tmpdirname))
self.assertTrue("adapter_config.json" in os.listdir(tmpdirname))
self.assertTrue("pytorch_model.bin" not in os.listdir(tmpdirname))
def test_peft_pipeline(self):
"""
Simple test that tests the basic usage of PEFT model + pipeline