Remove residual quantization attribute from dequantized models (#39373)
* fix: removing quantization trace attribute from dequantized model Fixes #39295 * add: test `to(dtype=torch.float16)` after dequantization
This commit is contained in:
@@ -248,6 +248,7 @@ class HfQuantizer(ABC):
|
|||||||
del model.hf_quantizer
|
del model.hf_quantizer
|
||||||
del model.config.quantization_config
|
del model.config.quantization_config
|
||||||
del model.config._pre_quantization_dtype
|
del model.config._pre_quantization_dtype
|
||||||
|
del model.quantization_method
|
||||||
model.is_quantized = False
|
model.is_quantized = False
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|||||||
@@ -271,6 +271,33 @@ class Bnb4BitTest(Base4bitTest):
|
|||||||
|
|
||||||
self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
|
self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
|
||||||
|
|
||||||
|
def test_clear_quantization_trace(self):
|
||||||
|
r"""
|
||||||
|
Test that dequantizing the model won't leave any attribute relative to quantization in the model's configuration
|
||||||
|
"""
|
||||||
|
bnb_config = BitsAndBytesConfig(load_in_4bit=True)
|
||||||
|
model_4bit = AutoModelForCausalLM.from_pretrained(
|
||||||
|
self.model_name, quantization_config=bnb_config, device_map="auto"
|
||||||
|
)
|
||||||
|
model_4bit.dequantize()
|
||||||
|
|
||||||
|
self.assertFalse(hasattr(model_4bit, "hf_quantizer"))
|
||||||
|
self.assertFalse(hasattr(model_4bit.config, "quantization_config"))
|
||||||
|
self.assertFalse(hasattr(model_4bit.config, "_pre_quantization_dtype"))
|
||||||
|
self.assertFalse(hasattr(model_4bit, "quantization_method"))
|
||||||
|
self.assertFalse(model_4bit.is_quantized)
|
||||||
|
|
||||||
|
def test_to_device_dequantized(self):
|
||||||
|
r"""
|
||||||
|
Test that dequantizing the model won't prevent converting it to a different dtype
|
||||||
|
"""
|
||||||
|
bnb_config = BitsAndBytesConfig(load_in_4bit=True)
|
||||||
|
model_4bit = AutoModelForCausalLM.from_pretrained(
|
||||||
|
self.model_name, quantization_config=bnb_config, device_map="auto"
|
||||||
|
)
|
||||||
|
model_4bit.dequantize()
|
||||||
|
model_4bit.to(dtype=torch.float16)
|
||||||
|
|
||||||
def test_device_assignment(self):
|
def test_device_assignment(self):
|
||||||
if version.parse(importlib.metadata.version("bitsandbytes")) < version.parse("0.43.2"):
|
if version.parse(importlib.metadata.version("bitsandbytes")) < version.parse("0.43.2"):
|
||||||
self.skipTest(reason="This test requires bitsandbytes >= 0.43.2")
|
self.skipTest(reason="This test requires bitsandbytes >= 0.43.2")
|
||||||
|
|||||||
Reference in New Issue
Block a user