Fix MXFP4 quantizer validation to allow CPU inference with dequantize option (#39953)
* Fix MXFP4 quantizer validation to enable CPU dequantization Move dequantize check before CUDA availability check to allow CPU inference when quantization_config.dequantize is True. This enables users to run MXFP4 models on CPU by automatically converting them to BF16 format. * Add tests for MXFP4 quantizer CPU dequantization validation * fix: format mxfp4 test file with ruff
This commit is contained in:
@@ -56,15 +56,16 @@ class Mxfp4HfQuantizer(HfQuantizer):
|
|||||||
"Using mxfp4 quantization requires torch"
|
"Using mxfp4 quantization requires torch"
|
||||||
"Please install the latest version of torch ( pip install --upgrade torch )"
|
"Please install the latest version of torch ( pip install --upgrade torch )"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.quantization_config.dequantize:
|
||||||
|
return
|
||||||
|
|
||||||
if not torch.cuda.is_available():
|
if not torch.cuda.is_available():
|
||||||
raise RuntimeError("Using MXFP4 quantized models requires a GPU")
|
raise RuntimeError("Using MXFP4 quantized models requires a GPU")
|
||||||
|
|
||||||
if not is_accelerate_available():
|
if not is_accelerate_available():
|
||||||
raise ImportError("Using mxfp4 requires Accelerate: `pip install accelerate`")
|
raise ImportError("Using mxfp4 requires Accelerate: `pip install accelerate`")
|
||||||
|
|
||||||
if self.quantization_config.dequantize:
|
|
||||||
return
|
|
||||||
|
|
||||||
compute_capability = torch.cuda.get_device_capability()
|
compute_capability = torch.cuda.get_device_capability()
|
||||||
major, minor = compute_capability
|
major, minor = compute_capability
|
||||||
|
|
||||||
|
|||||||
@@ -131,6 +131,52 @@ class Mxfp4QuantizerTest(unittest.TestCase):
|
|||||||
if "compute capability" in str(e):
|
if "compute capability" in str(e):
|
||||||
self.fail("Should not raise compute capability error when dequantize=True")
|
self.fail("Should not raise compute capability error when dequantize=True")
|
||||||
|
|
||||||
|
def test_quantizer_validation_dequantize_on_cpu(self):
|
||||||
|
"""Test quantizer validation with dequantize enabled on CPU-only environment"""
|
||||||
|
with patch("torch.cuda.is_available", return_value=False):
|
||||||
|
from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer
|
||||||
|
|
||||||
|
config = Mxfp4Config(dequantize=True)
|
||||||
|
quantizer = Mxfp4HfQuantizer(config)
|
||||||
|
|
||||||
|
# Should not raise error when dequantize=True even without CUDA
|
||||||
|
try:
|
||||||
|
quantizer.validate_environment()
|
||||||
|
except RuntimeError as e:
|
||||||
|
if "requires a GPU" in str(e):
|
||||||
|
self.fail("Should not raise GPU requirement error when dequantize=True on CPU")
|
||||||
|
|
||||||
|
def test_quantizer_validation_order_dequantize_before_cuda_check(self):
|
||||||
|
"""Test that dequantize check happens before CUDA availability check"""
|
||||||
|
# Mock both torch.cuda.is_available and is_accelerate_available to return False
|
||||||
|
with (
|
||||||
|
patch("torch.cuda.is_available", return_value=False),
|
||||||
|
patch(
|
||||||
|
"transformers.quantizers.quantizer_mxfp4.is_accelerate_available",
|
||||||
|
return_value=False,
|
||||||
|
),
|
||||||
|
):
|
||||||
|
from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer
|
||||||
|
|
||||||
|
# Test with dequantize=True - should pass even without CUDA and accelerate
|
||||||
|
config = Mxfp4Config(dequantize=True)
|
||||||
|
quantizer = Mxfp4HfQuantizer(config)
|
||||||
|
|
||||||
|
# This should not raise any error because dequantize check comes first
|
||||||
|
try:
|
||||||
|
quantizer.validate_environment()
|
||||||
|
except (RuntimeError, ImportError) as e:
|
||||||
|
if "requires a GPU" in str(e) or "requires Accelerate" in str(e):
|
||||||
|
self.fail(f"Should not raise error when dequantize=True: {e}")
|
||||||
|
|
||||||
|
# Test with dequantize=False - should still fail due to missing CUDA
|
||||||
|
config = Mxfp4Config(dequantize=False)
|
||||||
|
quantizer = Mxfp4HfQuantizer(config)
|
||||||
|
|
||||||
|
with self.assertRaises(RuntimeError) as context:
|
||||||
|
quantizer.validate_environment()
|
||||||
|
self.assertIn("requires a GPU", str(context.exception))
|
||||||
|
|
||||||
def test_quantizer_validation_missing_triton(self):
|
def test_quantizer_validation_missing_triton(self):
|
||||||
"""Test quantizer validation when triton is not available"""
|
"""Test quantizer validation when triton is not available"""
|
||||||
with (
|
with (
|
||||||
|
|||||||
Reference in New Issue
Block a user