TST Fix PEFT integration test bitsandbytes config (#39082)
TST Fix PEFT integration test bitsandbytes config
The PEFT integration tests still passed load_in_{4,8}bit directly to
from_pretrained, which is deprecated; they now pass a BitsAndBytesConfig
via quantization_config instead. For 4bit, also ensure that nf4 is used, to prevent
> RuntimeError: quant_type must be nf4 on CPU, got fp4
This commit is contained in:
@@ -25,6 +25,7 @@ from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
AutoModelForSequenceClassification,
|
||||
AutoTokenizer,
|
||||
BitsAndBytesConfig,
|
||||
OPTForCausalLM,
|
||||
Trainer,
|
||||
TrainingArguments,
|
||||
@@ -76,6 +77,12 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin):
|
||||
|
||||
return is_peft_loaded
|
||||
|
||||
def _get_bnb_4bit_config(self):
|
||||
return BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4")
|
||||
|
||||
def _get_bnb_8bit_config(self):
|
||||
return BitsAndBytesConfig(load_in_8bit=True)
|
||||
|
||||
def test_peft_from_pretrained(self):
|
||||
"""
|
||||
Simple test that tests the basic usage of PEFT model through `from_pretrained`.
|
||||
@@ -431,7 +438,10 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin):
|
||||
"""
|
||||
for model_id in self.peft_test_model_ids:
|
||||
for transformers_class in self.transformers_test_model_classes:
|
||||
peft_model = transformers_class.from_pretrained(model_id, load_in_8bit=True, device_map="auto")
|
||||
bnb_config = self._get_bnb_8bit_config()
|
||||
peft_model = transformers_class.from_pretrained(
|
||||
model_id, device_map="auto", quantization_config=bnb_config
|
||||
)
|
||||
|
||||
module = peft_model.model.decoder.layers[0].self_attn.v_proj
|
||||
self.assertTrue(module.__class__.__name__ == "Linear8bitLt")
|
||||
@@ -449,7 +459,10 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin):
|
||||
# 4bit
|
||||
for model_id in self.peft_test_model_ids:
|
||||
for transformers_class in self.transformers_test_model_classes:
|
||||
peft_model = transformers_class.from_pretrained(model_id, load_in_4bit=True, device_map="auto")
|
||||
bnb_config = self._get_bnb_4bit_config()
|
||||
peft_model = transformers_class.from_pretrained(
|
||||
model_id, device_map="auto", quantization_config=bnb_config
|
||||
)
|
||||
|
||||
module = peft_model.model.decoder.layers[0].self_attn.v_proj
|
||||
self.assertTrue(module.__class__.__name__ == "Linear4bit")
|
||||
@@ -465,7 +478,10 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin):
|
||||
# 8-bit
|
||||
for model_id in self.peft_test_model_ids:
|
||||
for transformers_class in self.transformers_test_model_classes:
|
||||
peft_model = transformers_class.from_pretrained(model_id, load_in_8bit=True, device_map="auto")
|
||||
bnb_config = self._get_bnb_8bit_config()
|
||||
peft_model = transformers_class.from_pretrained(
|
||||
model_id, device_map="auto", quantization_config=bnb_config
|
||||
)
|
||||
|
||||
module = peft_model.model.decoder.layers[0].self_attn.v_proj
|
||||
self.assertTrue(module.__class__.__name__ == "Linear8bitLt")
|
||||
@@ -489,7 +505,10 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin):
|
||||
# 4bit
|
||||
for model_id in self.peft_test_model_ids:
|
||||
for transformers_class in self.transformers_test_model_classes:
|
||||
peft_model = transformers_class.from_pretrained(model_id, load_in_4bit=True, device_map="auto")
|
||||
bnb_config = self._get_bnb_4bit_config()
|
||||
peft_model = transformers_class.from_pretrained(
|
||||
model_id, device_map="auto", quantization_config=bnb_config
|
||||
)
|
||||
|
||||
module = peft_model.model.decoder.layers[0].self_attn.v_proj
|
||||
self.assertTrue(module.__class__.__name__ == "Linear4bit")
|
||||
@@ -505,7 +524,10 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin):
|
||||
# 8-bit
|
||||
for model_id in self.peft_test_model_ids:
|
||||
for transformers_class in self.transformers_test_model_classes:
|
||||
peft_model = transformers_class.from_pretrained(model_id, load_in_8bit=True, device_map="auto")
|
||||
bnb_config = self._get_bnb_8bit_config()
|
||||
peft_model = transformers_class.from_pretrained(
|
||||
model_id, device_map="auto", quantization_config=bnb_config
|
||||
)
|
||||
|
||||
module = peft_model.model.decoder.layers[0].self_attn.v_proj
|
||||
self.assertTrue(module.__class__.__name__ == "Linear8bitLt")
|
||||
|
||||
Reference in New Issue
Block a user