Remove residual quantization attribute from dequantized models (#39373)

* fix: removing quantization trace attribute from dequantized model

Fixes #39295

* add: test `to(dtype=torch.float16)` after dequantization
This commit is contained in:
Dario Salvati
2025-07-15 17:16:10 +02:00
committed by GitHub
parent 30c508dbcb
commit 67f42928f0
2 changed files with 28 additions and 0 deletions

View File

@@ -271,6 +271,33 @@ class Bnb4BitTest(Base4bitTest):
self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
def test_clear_quantization_trace(self):
    r"""
    Dequantizing a 4-bit model must strip every quantization marker from it.

    After `dequantize()`, the model may no longer expose `hf_quantizer` or `quantization_method`, its config may
    no longer expose `quantization_config` or `_pre_quantization_dtype`, and `is_quantized` must be falsy.
    """
    quantization_config = BitsAndBytesConfig(load_in_4bit=True)
    dequantized_model = AutoModelForCausalLM.from_pretrained(
        self.model_name,
        quantization_config=quantization_config,
        device_map="auto",
    )
    dequantized_model.dequantize()

    # Markers stored directly on the model object.
    self.assertFalse(hasattr(dequantized_model, "hf_quantizer"))

    # Markers stored on the model's configuration.
    model_config = dequantized_model.config
    self.assertFalse(hasattr(model_config, "quantization_config"))
    self.assertFalse(hasattr(model_config, "_pre_quantization_dtype"))

    # Remaining model-level flags.
    self.assertFalse(hasattr(dequantized_model, "quantization_method"))
    self.assertFalse(dequantized_model.is_quantized)
def test_to_device_dequantized(self):
    r"""
    A dequantized 4-bit model must accept a dtype conversion through `.to(dtype=...)`.

    Despite the test name, no device move is performed: the model is only cast to `torch.float16`. The test
    makes no assertion on the result and passes as long as the call does not raise.
    """
    dequantized_model = AutoModelForCausalLM.from_pretrained(
        self.model_name,
        quantization_config=BitsAndBytesConfig(load_in_4bit=True),
        device_map="auto",
    )
    dequantized_model.dequantize()

    dequantized_model.to(dtype=torch.float16)
def test_device_assignment(self):
if version.parse(importlib.metadata.version("bitsandbytes")) < version.parse("0.43.2"):
self.skipTest(reason="This test requires bitsandbytes >= 0.43.2")