Don't allow passing load_in_8bit and load_in_4bit at the same time (#28266)
* Update quantization_config.py * Style * Protect from setting directly * add tests * Update tests/quantization/bnb/test_4bit.py Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> --------- Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
This commit is contained in:
@@ -648,3 +648,18 @@ class GPTSerializationTest(BaseSerializationTest):
|
||||
"""
|
||||
|
||||
model_name = "gpt2-xl"
|
||||
|
||||
|
||||
@require_bitsandbytes
|
||||
@require_accelerate
|
||||
@require_torch_gpu
|
||||
@slow
|
||||
class Bnb4BitTestBasicConfigTest(unittest.TestCase):
|
||||
def test_load_in_4_and_8_bit_fails(self):
|
||||
with self.assertRaisesRegex(ValueError, "load_in_4bit and load_in_8bit are both True"):
|
||||
AutoModelForCausalLM.from_pretrained("facebook/opt-125m", load_in_4bit=True, load_in_8bit=True)
|
||||
|
||||
def test_set_load_in_8_bit(self):
|
||||
quantization_config = BitsAndBytesConfig(load_in_4bit=True)
|
||||
with self.assertRaisesRegex(ValueError, "load_in_4bit and load_in_8bit are both True"):
|
||||
quantization_config.load_in_8bit = True
|
||||
|
||||
Reference in New Issue
Block a user