Fix MXFP4 quantizer validation to allow CPU inference with dequantize option (#39953)
* Fix MXFP4 quantizer validation to enable CPU dequantization Move dequantize check before CUDA availability check to allow CPU inference when quantization_config.dequantize is True. This enables users to run MXFP4 models on CPU by automatically converting them to BF16 format. * Add tests for MXFP4 quantizer CPU dequantization validation * fix: format mxfp4 test file with ruff
This commit is contained in:
@@ -131,6 +131,52 @@ class Mxfp4QuantizerTest(unittest.TestCase):
|
||||
if "compute capability" in str(e):
|
||||
self.fail("Should not raise compute capability error when dequantize=True")
|
||||
|
||||
def test_quantizer_validation_dequantize_on_cpu(self):
|
||||
"""Test quantizer validation with dequantize enabled on CPU-only environment"""
|
||||
with patch("torch.cuda.is_available", return_value=False):
|
||||
from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer
|
||||
|
||||
config = Mxfp4Config(dequantize=True)
|
||||
quantizer = Mxfp4HfQuantizer(config)
|
||||
|
||||
# Should not raise error when dequantize=True even without CUDA
|
||||
try:
|
||||
quantizer.validate_environment()
|
||||
except RuntimeError as e:
|
||||
if "requires a GPU" in str(e):
|
||||
self.fail("Should not raise GPU requirement error when dequantize=True on CPU")
|
||||
|
||||
def test_quantizer_validation_order_dequantize_before_cuda_check(self):
|
||||
"""Test that dequantize check happens before CUDA availability check"""
|
||||
# Mock both torch.cuda.is_available and is_accelerate_available to return False
|
||||
with (
|
||||
patch("torch.cuda.is_available", return_value=False),
|
||||
patch(
|
||||
"transformers.quantizers.quantizer_mxfp4.is_accelerate_available",
|
||||
return_value=False,
|
||||
),
|
||||
):
|
||||
from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer
|
||||
|
||||
# Test with dequantize=True - should pass even without CUDA and accelerate
|
||||
config = Mxfp4Config(dequantize=True)
|
||||
quantizer = Mxfp4HfQuantizer(config)
|
||||
|
||||
# This should not raise any error because dequantize check comes first
|
||||
try:
|
||||
quantizer.validate_environment()
|
||||
except (RuntimeError, ImportError) as e:
|
||||
if "requires a GPU" in str(e) or "requires Accelerate" in str(e):
|
||||
self.fail(f"Should not raise error when dequantize=True: {e}")
|
||||
|
||||
# Test with dequantize=False - should still fail due to missing CUDA
|
||||
config = Mxfp4Config(dequantize=False)
|
||||
quantizer = Mxfp4HfQuantizer(config)
|
||||
|
||||
with self.assertRaises(RuntimeError) as context:
|
||||
quantizer.validate_environment()
|
||||
self.assertIn("requires a GPU", str(context.exception))
|
||||
|
||||
def test_quantizer_validation_missing_triton(self):
|
||||
"""Test quantizer validation when triton is not available"""
|
||||
with (
|
||||
|
||||
Reference in New Issue
Block a user