Don't allow passing load_in_8bit and load_in_4bit at the same time (#28266)
* Update quantization_config.py
* Style
* Protect from setting directly
* add tests
* Update tests/quantization/bnb/test_4bit.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
This commit is contained in:
@@ -212,8 +212,12 @@ class BitsAndBytesConfig(QuantizationConfigMixin):
|
||||
**kwargs,
|
||||
):
|
||||
self.quant_method = QuantizationMethod.BITS_AND_BYTES
|
||||
self.load_in_8bit = load_in_8bit
|
||||
self.load_in_4bit = load_in_4bit
|
||||
|
||||
if load_in_4bit and load_in_8bit:
|
||||
raise ValueError("load_in_4bit and load_in_8bit are both True, but only one can be used at the same time")
|
||||
|
||||
self._load_in_8bit = load_in_8bit
|
||||
self._load_in_4bit = load_in_4bit
|
||||
self.llm_int8_threshold = llm_int8_threshold
|
||||
self.llm_int8_skip_modules = llm_int8_skip_modules
|
||||
self.llm_int8_enable_fp32_cpu_offload = llm_int8_enable_fp32_cpu_offload
|
||||
@@ -232,6 +236,26 @@ class BitsAndBytesConfig(QuantizationConfigMixin):
|
||||
|
||||
self.post_init()
|
||||
|
||||
@property
def load_in_4bit(self):
    """Return the 4-bit loading flag stored in `_load_in_4bit`."""
    return self._load_in_4bit


@load_in_4bit.setter
def load_in_4bit(self, value: bool):
    """Store the 4-bit loading flag, refusing to combine it with 8-bit loading.

    Raises:
        ValueError: If `value` is truthy while `self.load_in_8bit` is truthy.
    """
    # The check runs before the assignment, so a rejected update leaves the
    # previously stored value untouched.
    if self.load_in_8bit and value:
        raise ValueError("load_in_4bit and load_in_8bit are both True, but only one can be used at the same time")
    self._load_in_4bit = value
|
||||
|
||||
@property
def load_in_8bit(self):
    """Return the 8-bit loading flag stored in `_load_in_8bit`."""
    return self._load_in_8bit


@load_in_8bit.setter
def load_in_8bit(self, value: bool):
    """Store the 8-bit loading flag, refusing to combine it with 4-bit loading.

    Raises:
        ValueError: If `value` is truthy while `self.load_in_4bit` is truthy.
    """
    # The check runs before the assignment, so a rejected update leaves the
    # previously stored value untouched.
    if self.load_in_4bit and value:
        raise ValueError("load_in_4bit and load_in_8bit are both True, but only one can be used at the same time")
    self._load_in_8bit = value
|
||||
|
||||
def post_init(self):
|
||||
r"""
|
||||
Safety checker that arguments are correct - also replaces some NoneType arguments with their default values.
|
||||
|
||||
@@ -648,3 +648,18 @@ class GPTSerializationTest(BaseSerializationTest):
|
||||
"""
|
||||
|
||||
model_name = "gpt2-xl"
|
||||
|
||||
|
||||
@require_bitsandbytes
@require_accelerate
@require_torch_gpu
@slow
class Bnb4BitTestBasicConfigTest(unittest.TestCase):
    """Checks that 4-bit and 8-bit loading cannot be requested together."""

    def test_load_in_4_and_8_bit_fails(self):
        # Passing both flags to `from_pretrained` must be rejected.
        with self.assertRaisesRegex(ValueError, "load_in_4bit and load_in_8bit are both True"):
            AutoModelForCausalLM.from_pretrained("facebook/opt-125m", load_in_4bit=True, load_in_8bit=True)

    def test_set_load_in_8_bit(self):
        # Enabling 8-bit on a config created with 4-bit enabled must be rejected.
        quantization_config = BitsAndBytesConfig(load_in_4bit=True)
        with self.assertRaisesRegex(ValueError, "load_in_4bit and load_in_8bit are both True"):
            quantization_config.load_in_8bit = True
|
||||
|
||||
Reference in New Issue
Block a user