Fix MXFP4 quantizer validation to allow CPU inference with dequantize option (#39953)

* Fix MXFP4 quantizer validation to enable CPU dequantization

Move dequantize check before CUDA availability check to allow
CPU inference when quantization_config.dequantize is True.
This enables users to run MXFP4 models on CPU by automatically
converting them to BF16 format.

* Add tests for MXFP4 quantizer CPU dequantization validation

* fix: format mxfp4 test file with ruff
This commit is contained in:
Lintch
2025-08-06 21:20:41 +08:00
committed by GitHub
parent 82eb67e62a
commit dd70a8cb9d
2 changed files with 50 additions and 3 deletions

View File

@@ -56,15 +56,16 @@ class Mxfp4HfQuantizer(HfQuantizer):
"Using mxfp4 quantization requires torch"
"Please install the latest version of torch ( pip install --upgrade torch )"
)
if self.quantization_config.dequantize:
return
if not torch.cuda.is_available():
raise RuntimeError("Using MXFP4 quantized models requires a GPU")
if not is_accelerate_available():
raise ImportError("Using mxfp4 requires Accelerate: `pip install accelerate`")
if self.quantization_config.dequantize:
return
compute_capability = torch.cuda.get_device_capability()
major, minor = compute_capability

View File

@@ -131,6 +131,52 @@ class Mxfp4QuantizerTest(unittest.TestCase):
if "compute capability" in str(e):
self.fail("Should not raise compute capability error when dequantize=True")
def test_quantizer_validation_dequantize_on_cpu(self):
"""Test quantizer validation with dequantize enabled on CPU-only environment"""
with patch("torch.cuda.is_available", return_value=False):
from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer
config = Mxfp4Config(dequantize=True)
quantizer = Mxfp4HfQuantizer(config)
# Should not raise error when dequantize=True even without CUDA
try:
quantizer.validate_environment()
except RuntimeError as e:
if "requires a GPU" in str(e):
self.fail("Should not raise GPU requirement error when dequantize=True on CPU")
def test_quantizer_validation_order_dequantize_before_cuda_check(self):
"""Test that dequantize check happens before CUDA availability check"""
# Mock both torch.cuda.is_available and is_accelerate_available to return False
with (
patch("torch.cuda.is_available", return_value=False),
patch(
"transformers.quantizers.quantizer_mxfp4.is_accelerate_available",
return_value=False,
),
):
from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer
# Test with dequantize=True - should pass even without CUDA and accelerate
config = Mxfp4Config(dequantize=True)
quantizer = Mxfp4HfQuantizer(config)
# This should not raise any error because dequantize check comes first
try:
quantizer.validate_environment()
except (RuntimeError, ImportError) as e:
if "requires a GPU" in str(e) or "requires Accelerate" in str(e):
self.fail(f"Should not raise error when dequantize=True: {e}")
# Test with dequantize=False - should still fail due to missing CUDA
config = Mxfp4Config(dequantize=False)
quantizer = Mxfp4HfQuantizer(config)
with self.assertRaises(RuntimeError) as context:
quantizer.validate_environment()
self.assertIn("requires a GPU", str(context.exception))
def test_quantizer_validation_missing_triton(self):
"""Test quantizer validation when triton is not available"""
with (