Enable gpt-oss mxfp4 on older hardware (sm75+) (#39940)

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
This commit is contained in:
Matthew Douglas
2025-08-06 09:39:21 -04:00
committed by GitHub
parent dd70a8cb9d
commit c7844c7a8e
3 changed files with 40 additions and 14 deletions

View File

@@ -107,18 +107,31 @@ class Mxfp4QuantizerTest(unittest.TestCase):
def test_quantizer_validation_low_compute_capability(self):
"""Test quantizer validation with low compute capability"""
with patch("torch.cuda.get_device_capability", return_value=(8, 0)):
with patch("torch.cuda.get_device_capability", return_value=(7, 0)):
from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer
config = Mxfp4Config()
quantizer = Mxfp4HfQuantizer(config)
quantizer.pre_quantized = False
with self.assertRaises(ValueError):
quantizer.validate_environment()
def test_quantizer_validation_low_compute_capability_with_prequantized(self):
"""Test quantizer validation with low compute capability"""
with patch("torch.cuda.get_device_capability", return_value=(7, 0)):
from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer
config = Mxfp4Config()
quantizer = Mxfp4HfQuantizer(config)
with self.assertRaises(ValueError):
quantizer.validate_environment()
# Should automatically set dequantize=True and warn
quantizer.validate_environment()
self.assertTrue(quantizer.quantization_config.dequantize)
def test_quantizer_validation_low_compute_capability_with_dequantize(self):
"""Test quantizer validation with low compute capability but dequantize enabled"""
with patch("torch.cuda.get_device_capability", return_value=(8, 0)):
with patch("torch.cuda.get_device_capability", return_value=(7, 0)):
from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer
config = Mxfp4Config(dequantize=True)