FEAT / Bitsandbytes: Add dequantize API for bitsandbytes quantized models (#30806)

* add  method

* change method name

* more comments

* Apply suggestions from code review

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* fixup

* add docstrings and fix comment

* warn users on the de-quantized dtype

* Update src/transformers/quantizers/base.py

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* Update src/transformers/integrations/bitsandbytes.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* final suggestion - use private method

---------

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
Younes Belkada
2024-05-15 17:17:09 +02:00
committed by GitHub
parent 58faa7b824
commit 3f435823e0
9 changed files with 244 additions and 1 deletions

View File

@@ -285,6 +285,23 @@ class MixedInt8Test(BaseMixedInt8Test):
self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
def test_generate_quality_dequantize(self):
r"""
Test that loading the model and dequantizing it produce correct results
"""
bnb_config = BitsAndBytesConfig(load_in_8bit=True)
model_8bit = AutoModelForCausalLM.from_pretrained(
self.model_name, quantization_config=bnb_config, device_map="auto"
)
model_8bit.dequantize()
encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
output_sequences = model_8bit.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10)
self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
def test_raise_if_config_and_load_in_8bit(self):
r"""
Test that loading the model with the config and `load_in_8bit` raises an error