Fix MXFP4 quantizer validation to allow CPU inference with dequantize option (#39953)

* Fix MXFP4 quantizer validation to enable CPU dequantization

Move dequantize check before CUDA availability check to allow
CPU inference when quantization_config.dequantize is True.
This enables users to run MXFP4 models on CPU by automatically
converting them to BF16 format.

* Add tests for MXFP4 quantizer CPU dequantization validation

* fix: format mxfp4 test file with ruff
This commit is contained in:
Lintch
2025-08-06 21:20:41 +08:00
committed by GitHub
parent 82eb67e62a
commit dd70a8cb9d
2 changed files with 50 additions and 3 deletions

View File

@@ -131,6 +131,52 @@ class Mxfp4QuantizerTest(unittest.TestCase):
if "compute capability" in str(e):
self.fail("Should not raise compute capability error when dequantize=True")
def test_quantizer_validation_dequantize_on_cpu(self):
    """Environment validation must not demand a GPU when dequantize=True.

    CUDA is patched to report as unavailable; the test fails only if
    ``validate_environment`` raises a ``RuntimeError`` whose message
    contains "requires a GPU".
    """
    with patch("torch.cuda.is_available", return_value=False):
        from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer

        cpu_quantizer = Mxfp4HfQuantizer(Mxfp4Config(dequantize=True))

        # Only the GPU-requirement error counts as a failure here; any other
        # RuntimeError is deliberately ignored by this test.
        try:
            cpu_quantizer.validate_environment()
        except RuntimeError as err:
            gpu_demanded = "requires a GPU" in str(err)
            if gpu_demanded:
                self.fail("Should not raise GPU requirement error when dequantize=True on CPU")
def test_quantizer_validation_order_dequantize_before_cuda_check(self):
    """The dequantize check must run before the CUDA availability check.

    Both ``torch.cuda.is_available`` and the quantizer module's
    ``is_accelerate_available`` are patched to return False. With
    dequantize=True no GPU/Accelerate error may surface; with
    dequantize=False a ``RuntimeError`` mentioning "requires a GPU" is
    expected.
    """
    no_cuda = patch("torch.cuda.is_available", return_value=False)
    no_accelerate = patch(
        "transformers.quantizers.quantizer_mxfp4.is_accelerate_available",
        return_value=False,
    )

    with no_cuda, no_accelerate:
        from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer

        # Phase 1: dequantize=True should pass without CUDA or Accelerate.
        dequantizing = Mxfp4HfQuantizer(Mxfp4Config(dequantize=True))
        try:
            dequantizing.validate_environment()
        except (RuntimeError, ImportError) as e:
            text = str(e)
            forbidden = ("requires a GPU", "requires Accelerate")
            if any(marker in text for marker in forbidden):
                self.fail(f"Should not raise error when dequantize=True: {e}")

        # Phase 2: dequantize=False must still be rejected for lack of CUDA.
        strict = Mxfp4HfQuantizer(Mxfp4Config(dequantize=False))
        with self.assertRaises(RuntimeError) as raised:
            strict.validate_environment()
        self.assertIn("requires a GPU", str(raised.exception))
def test_quantizer_validation_missing_triton(self):
"""Test quantizer validation when triton is not available"""
with (