Don't allow passing load_in_8bit and load_in_4bit at the same time (#28266)

* Update quantization_config.py

* Style

* Protect from setting directly

* add tests

* Update tests/quantization/bnb/test_4bit.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
This commit is contained in:
Omar Sanseviero
2024-01-30 01:43:40 +01:00
committed by GitHub
parent cd2eb8cb2b
commit a989c6c6eb
2 changed files with 41 additions and 2 deletions

View File

@@ -648,3 +648,18 @@ class GPTSerializationTest(BaseSerializationTest):
"""
model_name = "gpt2-xl"
@require_bitsandbytes
@require_accelerate
@require_torch_gpu
@slow
class Bnb4BitTestBasicConfigTest(unittest.TestCase):
    """Checks that 4-bit and 8-bit loading are rejected when enabled together."""

    def test_load_in_4_and_8_bit_fails(self):
        """Passing both flags to `from_pretrained` is expected to raise a ValueError."""
        expected_error = "load_in_4bit and load_in_8bit are both True"
        both_flags = {"load_in_4bit": True, "load_in_8bit": True}
        with self.assertRaisesRegex(ValueError, expected_error):
            AutoModelForCausalLM.from_pretrained("facebook/opt-125m", **both_flags)

    def test_set_load_in_8_bit(self):
        """Assigning `load_in_8bit` on a 4-bit config is expected to raise a ValueError."""
        config = BitsAndBytesConfig(load_in_4bit=True)
        # The assignment itself must trigger the error, so it lives inside the context.
        with self.assertRaisesRegex(ValueError, "load_in_4bit and load_in_8bit are both True"):
            config.load_in_8bit = True