Remove residual quantization attribute from dequantized models (#39373)
* fix: remove the residual quantization trace attributes from the dequantized model (fixes #39295)
* add: test `to(dtype=torch.float16)` after dequantization
This commit is contained in:
@@ -271,6 +271,33 @@ class Bnb4BitTest(Base4bitTest):
|
||||
|
||||
self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
|
||||
|
||||
def test_clear_quantization_trace(self):
    r"""
    Check that calling `dequantize()` on a 4-bit model leaves no quantization-related attribute behind, neither on
    the model itself nor on its configuration.
    """
    quantization_config = BitsAndBytesConfig(load_in_4bit=True)
    model = AutoModelForCausalLM.from_pretrained(
        self.model_name,
        quantization_config=quantization_config,
        device_map="auto",
    )
    model.dequantize()

    # (owner, attribute) pairs that must no longer exist once the model is dequantized.
    config = model.config
    stale_attributes = (
        (model, "hf_quantizer"),
        (config, "quantization_config"),
        (config, "_pre_quantization_dtype"),
        (model, "quantization_method"),
    )
    for owner, attribute in stale_attributes:
        self.assertFalse(hasattr(owner, attribute))

    # The model must also stop reporting itself as quantized.
    self.assertFalse(model.is_quantized)
|
||||
|
||||
def test_to_device_dequantized(self):
    r"""
    Check that a model which has been dequantized can still be cast to a different dtype.
    """
    quantization_config = BitsAndBytesConfig(load_in_4bit=True)
    model = AutoModelForCausalLM.from_pretrained(
        self.model_name,
        quantization_config=quantization_config,
        device_map="auto",
    )
    model.dequantize()

    # No explicit assertion: the test fails if the dtype cast raises.
    model.to(dtype=torch.float16)
|
||||
|
||||
def test_device_assignment(self):
|
||||
if version.parse(importlib.metadata.version("bitsandbytes")) < version.parse("0.43.2"):
|
||||
self.skipTest(reason="This test requires bitsandbytes >= 0.43.2")
|
||||
|
||||
Reference in New Issue
Block a user