Don't allow passing load_in_8bit and load_in_4bit at the same time (#28266)

* Update quantization_config.py

* Style

* Protect from setting directly

* add tests

* Update tests/quantization/bnb/test_4bit.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
This commit is contained in:
Omar Sanseviero
2024-01-30 01:43:40 +01:00
committed by GitHub
parent cd2eb8cb2b
commit a989c6c6eb
2 changed files with 41 additions and 2 deletions

View File

@@ -212,8 +212,12 @@ class BitsAndBytesConfig(QuantizationConfigMixin):
**kwargs, **kwargs,
): ):
self.quant_method = QuantizationMethod.BITS_AND_BYTES self.quant_method = QuantizationMethod.BITS_AND_BYTES
self.load_in_8bit = load_in_8bit
self.load_in_4bit = load_in_4bit if load_in_4bit and load_in_8bit:
raise ValueError("load_in_4bit and load_in_8bit are both True, but only one can be used at the same time")
self._load_in_8bit = load_in_8bit
self._load_in_4bit = load_in_4bit
self.llm_int8_threshold = llm_int8_threshold self.llm_int8_threshold = llm_int8_threshold
self.llm_int8_skip_modules = llm_int8_skip_modules self.llm_int8_skip_modules = llm_int8_skip_modules
self.llm_int8_enable_fp32_cpu_offload = llm_int8_enable_fp32_cpu_offload self.llm_int8_enable_fp32_cpu_offload = llm_int8_enable_fp32_cpu_offload
@@ -232,6 +236,26 @@ class BitsAndBytesConfig(QuantizationConfigMixin):
self.post_init() self.post_init()
@property
def load_in_4bit(self):
return self._load_in_4bit
@load_in_4bit.setter
def load_in_4bit(self, value: bool):
if self.load_in_8bit and value:
raise ValueError("load_in_4bit and load_in_8bit are both True, but only one can be used at the same time")
self._load_in_4bit = value
@property
def load_in_8bit(self):
return self._load_in_8bit
@load_in_8bit.setter
def load_in_8bit(self, value: bool):
if self.load_in_4bit and value:
raise ValueError("load_in_4bit and load_in_8bit are both True, but only one can be used at the same time")
self._load_in_8bit = value
def post_init(self): def post_init(self):
r""" r"""
Safety checker that arguments are correct - also replaces some NoneType arguments with their default values. Safety checker that arguments are correct - also replaces some NoneType arguments with their default values.

View File

@@ -648,3 +648,18 @@ class GPTSerializationTest(BaseSerializationTest):
""" """
model_name = "gpt2-xl" model_name = "gpt2-xl"
@require_bitsandbytes
@require_accelerate
@require_torch_gpu
@slow
class Bnb4BitTestBasicConfigTest(unittest.TestCase):
def test_load_in_4_and_8_bit_fails(self):
with self.assertRaisesRegex(ValueError, "load_in_4bit and load_in_8bit are both True"):
AutoModelForCausalLM.from_pretrained("facebook/opt-125m", load_in_4bit=True, load_in_8bit=True)
def test_set_load_in_8_bit(self):
quantization_config = BitsAndBytesConfig(load_in_4bit=True)
with self.assertRaisesRegex(ValueError, "load_in_4bit and load_in_8bit are both True"):
quantization_config.load_in_8bit = True