[PEFT] Final fixes (#26559)
* fix issues with PEFT * logger warning futurewarning issues * fixup * adapt from suggestions * oops * rm test
This commit is contained in:
@@ -312,6 +312,42 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin):
|
||||
# dummy generation
|
||||
_ = peft_model.generate(input_ids=torch.LongTensor([[0, 1, 2, 3, 4, 5, 6, 7]]).to(torch_device))
|
||||
|
||||
@require_torch_gpu
|
||||
def test_peft_save_quantized(self):
|
||||
"""
|
||||
Simple test that tests the basic usage of PEFT model save_pretrained with quantized base models
|
||||
"""
|
||||
# 4bit
|
||||
for model_id in self.peft_test_model_ids:
|
||||
for transformers_class in self.transformers_test_model_classes:
|
||||
peft_model = transformers_class.from_pretrained(model_id, load_in_4bit=True, device_map="auto")
|
||||
|
||||
module = peft_model.model.decoder.layers[0].self_attn.v_proj
|
||||
self.assertTrue(module.__class__.__name__ == "Linear4bit")
|
||||
self.assertTrue(peft_model.hf_device_map is not None)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
peft_model.save_pretrained(tmpdirname)
|
||||
self.assertTrue("adapter_model.bin" in os.listdir(tmpdirname))
|
||||
self.assertTrue("adapter_config.json" in os.listdir(tmpdirname))
|
||||
self.assertTrue("pytorch_model.bin" not in os.listdir(tmpdirname))
|
||||
|
||||
# 8-bit
|
||||
for model_id in self.peft_test_model_ids:
|
||||
for transformers_class in self.transformers_test_model_classes:
|
||||
peft_model = transformers_class.from_pretrained(model_id, load_in_8bit=True, device_map="auto")
|
||||
|
||||
module = peft_model.model.decoder.layers[0].self_attn.v_proj
|
||||
self.assertTrue(module.__class__.__name__ == "Linear8bitLt")
|
||||
self.assertTrue(peft_model.hf_device_map is not None)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
peft_model.save_pretrained(tmpdirname)
|
||||
|
||||
self.assertTrue("adapter_model.bin" in os.listdir(tmpdirname))
|
||||
self.assertTrue("adapter_config.json" in os.listdir(tmpdirname))
|
||||
self.assertTrue("pytorch_model.bin" not in os.listdir(tmpdirname))
|
||||
|
||||
def test_peft_pipeline(self):
|
||||
"""
|
||||
Simple test that tests the basic usage of PEFT model + pipeline
|
||||
|
||||
@@ -263,13 +263,6 @@ class MixedInt8Test(BaseMixedInt8Test):
|
||||
|
||||
self.assertEqual(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
|
||||
|
||||
def test_warns_save_pretrained(self):
|
||||
r"""
|
||||
Test whether trying to save a model after converting it in 8-bit will throw a warning.
|
||||
"""
|
||||
with self.assertWarns(UserWarning), tempfile.TemporaryDirectory() as tmpdirname:
|
||||
self.model_8bit.save_pretrained(tmpdirname)
|
||||
|
||||
def test_raise_if_config_and_load_in_8bit(self):
|
||||
r"""
|
||||
Test that loading the model with the config and `load_in_8bit` raises an error
|
||||
|
||||
Reference in New Issue
Block a user