mxfp4: run the dequantize early-return before the CUDA/Accelerate checks

Moves the `if self.quantization_config.dequantize: return` guard in
Mxfp4HfQuantizer so that it runs right after the torch-availability check
and before the CUDA-availability and Accelerate checks. With
dequantize=True, environment validation therefore no longer raises
"requires a GPU" / "requires Accelerate" errors. Adds two tests covering
the CPU-only case and the new ordering of the checks.

diff --git a/src/transformers/quantizers/quantizer_mxfp4.py b/src/transformers/quantizers/quantizer_mxfp4.py
index 061ca072f0..b3e52e5e1c 100644
--- a/src/transformers/quantizers/quantizer_mxfp4.py
+++ b/src/transformers/quantizers/quantizer_mxfp4.py
@@ -56,15 +56,16 @@ class Mxfp4HfQuantizer(HfQuantizer):
                 "Using mxfp4 quantization requires torch"
                 "Please install the latest version of torch ( pip install --upgrade torch )"
             )
+
+        if self.quantization_config.dequantize:
+            return
+
         if not torch.cuda.is_available():
             raise RuntimeError("Using MXFP4 quantized models requires a GPU")
 
         if not is_accelerate_available():
             raise ImportError("Using mxfp4 requires Accelerate: `pip install accelerate`")
 
-        if self.quantization_config.dequantize:
-            return
-
         compute_capability = torch.cuda.get_device_capability()
         major, minor = compute_capability
 
diff --git a/tests/quantization/mxfp4/test_mxfp4.py b/tests/quantization/mxfp4/test_mxfp4.py
index 2194c2d321..56268b44df 100644
--- a/tests/quantization/mxfp4/test_mxfp4.py
+++ b/tests/quantization/mxfp4/test_mxfp4.py
@@ -131,6 +131,52 @@ class Mxfp4QuantizerTest(unittest.TestCase):
                 if "compute capability" in str(e):
                     self.fail("Should not raise compute capability error when dequantize=True")
 
+    def test_quantizer_validation_dequantize_on_cpu(self):
+        """Test quantizer validation with dequantize enabled on CPU-only environment"""
+        with patch("torch.cuda.is_available", return_value=False):
+            from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer
+
+            config = Mxfp4Config(dequantize=True)
+            quantizer = Mxfp4HfQuantizer(config)
+
+            # Should not raise error when dequantize=True even without CUDA
+            try:
+                quantizer.validate_environment()
+            except RuntimeError as e:
+                if "requires a GPU" in str(e):
+                    self.fail("Should not raise GPU requirement error when dequantize=True on CPU")
+
+    def test_quantizer_validation_order_dequantize_before_cuda_check(self):
+        """Test that dequantize check happens before CUDA availability check"""
+        # Mock both torch.cuda.is_available and is_accelerate_available to return False
+        with (
+            patch("torch.cuda.is_available", return_value=False),
+            patch(
+                "transformers.quantizers.quantizer_mxfp4.is_accelerate_available",
+                return_value=False,
+            ),
+        ):
+            from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer
+
+            # Test with dequantize=True - should pass even without CUDA and accelerate
+            config = Mxfp4Config(dequantize=True)
+            quantizer = Mxfp4HfQuantizer(config)
+
+            # This should not raise any error because dequantize check comes first
+            try:
+                quantizer.validate_environment()
+            except (RuntimeError, ImportError) as e:
+                if "requires a GPU" in str(e) or "requires Accelerate" in str(e):
+                    self.fail(f"Should not raise error when dequantize=True: {e}")
+
+            # Test with dequantize=False - should still fail due to missing CUDA
+            config = Mxfp4Config(dequantize=False)
+            quantizer = Mxfp4HfQuantizer(config)
+
+            with self.assertRaises(RuntimeError) as context:
+                quantizer.validate_environment()
+            self.assertIn("requires a GPU", str(context.exception))
+
     def test_quantizer_validation_missing_triton(self):
         """Test quantizer validation when triton is not available"""
         with (