diff --git a/tests/peft_integration/test_peft_integration.py b/tests/peft_integration/test_peft_integration.py index 523137d53a..7efa5252e8 100644 --- a/tests/peft_integration/test_peft_integration.py +++ b/tests/peft_integration/test_peft_integration.py @@ -25,6 +25,7 @@ from transformers import ( AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer, + BitsAndBytesConfig, OPTForCausalLM, Trainer, TrainingArguments, @@ -76,6 +77,12 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin): return is_peft_loaded + def _get_bnb_4bit_config(self): + return BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4") + + def _get_bnb_8bit_config(self): + return BitsAndBytesConfig(load_in_8bit=True) + def test_peft_from_pretrained(self): """ Simple test that tests the basic usage of PEFT model through `from_pretrained`. @@ -431,7 +438,10 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin): """ for model_id in self.peft_test_model_ids: for transformers_class in self.transformers_test_model_classes: - peft_model = transformers_class.from_pretrained(model_id, load_in_8bit=True, device_map="auto") + bnb_config = self._get_bnb_8bit_config() + peft_model = transformers_class.from_pretrained( + model_id, device_map="auto", quantization_config=bnb_config + ) module = peft_model.model.decoder.layers[0].self_attn.v_proj self.assertTrue(module.__class__.__name__ == "Linear8bitLt") @@ -449,7 +459,10 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin): # 4bit for model_id in self.peft_test_model_ids: for transformers_class in self.transformers_test_model_classes: - peft_model = transformers_class.from_pretrained(model_id, load_in_4bit=True, device_map="auto") + bnb_config = self._get_bnb_4bit_config() + peft_model = transformers_class.from_pretrained( + model_id, device_map="auto", quantization_config=bnb_config + ) module = peft_model.model.decoder.layers[0].self_attn.v_proj self.assertTrue(module.__class__.__name__ == "Linear4bit") @@ -465,7 +478,10 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin): # 8-bit for model_id in self.peft_test_model_ids: for transformers_class in self.transformers_test_model_classes: - peft_model = transformers_class.from_pretrained(model_id, load_in_8bit=True, device_map="auto") + bnb_config = self._get_bnb_8bit_config() + peft_model = transformers_class.from_pretrained( + model_id, device_map="auto", quantization_config=bnb_config + ) module = peft_model.model.decoder.layers[0].self_attn.v_proj self.assertTrue(module.__class__.__name__ == "Linear8bitLt") @@ -489,7 +505,10 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin): # 4bit for model_id in self.peft_test_model_ids: for transformers_class in self.transformers_test_model_classes: - peft_model = transformers_class.from_pretrained(model_id, load_in_4bit=True, device_map="auto") + bnb_config = self._get_bnb_4bit_config() + peft_model = transformers_class.from_pretrained( + model_id, device_map="auto", quantization_config=bnb_config + ) module = peft_model.model.decoder.layers[0].self_attn.v_proj self.assertTrue(module.__class__.__name__ == "Linear4bit") @@ -505,7 +524,10 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin): # 8-bit for model_id in self.peft_test_model_ids: for transformers_class in self.transformers_test_model_classes: - peft_model = transformers_class.from_pretrained(model_id, load_in_8bit=True, device_map="auto") + bnb_config = self._get_bnb_8bit_config() + peft_model = transformers_class.from_pretrained( + model_id, device_map="auto", quantization_config=bnb_config + ) module = peft_model.model.decoder.layers[0].self_attn.v_proj self.assertTrue(module.__class__.__name__ == "Linear8bitLt")