From 67f42928f0ec97a4635e7ff52a4b5e7879590c1c Mon Sep 17 00:00:00 2001 From: Dario Salvati Date: Tue, 15 Jul 2025 17:16:10 +0200 Subject: [PATCH] Remove residual quantization attribute from dequantized models (#39373) * fix: removing quantization trace attribute from dequantized model Fixes #39295 * add: test `to(dtype=torch.float16)` after dequantization --- src/transformers/quantizers/base.py | 1 + tests/quantization/bnb/test_4bit.py | 27 +++++++++++++++++++++++++++ 2 files changed, 28 insertions(+) diff --git a/src/transformers/quantizers/base.py b/src/transformers/quantizers/base.py index d54c5f2966..0a4ddf6804 100644 --- a/src/transformers/quantizers/base.py +++ b/src/transformers/quantizers/base.py @@ -248,6 +248,7 @@ class HfQuantizer(ABC): del model.hf_quantizer del model.config.quantization_config del model.config._pre_quantization_dtype + del model.quantization_method model.is_quantized = False return model diff --git a/tests/quantization/bnb/test_4bit.py b/tests/quantization/bnb/test_4bit.py index fd72d13505..3c0977f21c 100644 --- a/tests/quantization/bnb/test_4bit.py +++ b/tests/quantization/bnb/test_4bit.py @@ -271,6 +271,33 @@ class Bnb4BitTest(Base4bitTest): self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS) + def test_clear_quantization_trace(self): + r""" + Test that dequantizing the model won't leave any attribute relative to quantization in the model's configuration + """ + bnb_config = BitsAndBytesConfig(load_in_4bit=True) + model_4bit = AutoModelForCausalLM.from_pretrained( + self.model_name, quantization_config=bnb_config, device_map="auto" + ) + model_4bit.dequantize() + + self.assertFalse(hasattr(model_4bit, "hf_quantizer")) + self.assertFalse(hasattr(model_4bit.config, "quantization_config")) + self.assertFalse(hasattr(model_4bit.config, "_pre_quantization_dtype")) + self.assertFalse(hasattr(model_4bit, "quantization_method")) + self.assertFalse(model_4bit.is_quantized) + + def test_to_device_dequantized(self): + r""" + Test that dequantizing the model won't prevent converting it to a different dtype + """ + bnb_config = BitsAndBytesConfig(load_in_4bit=True) + model_4bit = AutoModelForCausalLM.from_pretrained( + self.model_name, quantization_config=bnb_config, device_map="auto" + ) + model_4bit.dequantize() + model_4bit.to(dtype=torch.float16) + def test_device_assignment(self): if version.parse(importlib.metadata.version("bitsandbytes")) < version.parse("0.43.2"): self.skipTest(reason="This test requires bitsandbytes >= 0.43.2")