Enable gpt-oss mxfp4 on older hardware (sm75+) (#39940)
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
This commit is contained in:
@@ -280,7 +280,10 @@ def mlp_forward(self, hidden_states):
|
|||||||
batch_size = hidden_states.shape[0]
|
batch_size = hidden_states.shape[0]
|
||||||
hidden_states = hidden_states.reshape(-1, self.router.hidden_dim)
|
hidden_states = hidden_states.reshape(-1, self.router.hidden_dim)
|
||||||
router_logits = nn.functional.linear(hidden_states, self.router.weight, self.router.bias)
|
router_logits = nn.functional.linear(hidden_states, self.router.weight, self.router.bias)
|
||||||
|
|
||||||
|
with torch.cuda.device(router_logits.device):
|
||||||
routing_data, gather_idx, scatter_idx = routing(router_logits, self.router.top_k)
|
routing_data, gather_idx, scatter_idx = routing(router_logits, self.router.top_k)
|
||||||
|
|
||||||
routed_out = self.experts(hidden_states, routing_data, gather_idx, scatter_idx)
|
routed_out = self.experts(hidden_states, routing_data, gather_idx, scatter_idx)
|
||||||
routed_out = routed_out.reshape(batch_size, -1, self.router.hidden_dim)
|
routed_out = routed_out.reshape(batch_size, -1, self.router.hidden_dim)
|
||||||
return routed_out, router_logits
|
return routed_out, router_logits
|
||||||
|
|||||||
@@ -67,24 +67,34 @@ class Mxfp4HfQuantizer(HfQuantizer):
|
|||||||
raise ImportError("Using mxfp4 requires Accelerate: `pip install accelerate`")
|
raise ImportError("Using mxfp4 requires Accelerate: `pip install accelerate`")
|
||||||
|
|
||||||
compute_capability = torch.cuda.get_device_capability()
|
compute_capability = torch.cuda.get_device_capability()
|
||||||
major, minor = compute_capability
|
gpu_is_supported = compute_capability >= (7, 5)
|
||||||
|
kernels_available = is_triton_available("3.4.0") and is_triton_kernels_availalble()
|
||||||
|
|
||||||
if not is_triton_available("3.4.0") or not is_triton_kernels_availalble():
|
if self.pre_quantized:
|
||||||
if self.pre_quantized and not self.quantization_config.dequantize:
|
# On unsupported GPUs or without kernels, we will dequantize the model to bf16
|
||||||
|
if not gpu_is_supported:
|
||||||
|
logger.warning_once(
|
||||||
|
"MXFP4 quantization is only supported on GPUs with compute capability >= 7.5 (e.g T4, A100, L4, H100, or B200). "
|
||||||
|
"We will default to dequantizing the model to bf16."
|
||||||
|
)
|
||||||
|
self.quantization_config.dequantize = True
|
||||||
|
return
|
||||||
|
|
||||||
|
if not kernels_available:
|
||||||
logger.warning_once(
|
logger.warning_once(
|
||||||
"MXFP4 quantization requires triton >= 3.4.0 and triton_kernels installed, we will default to dequantizing the model to bf16"
|
"MXFP4 quantization requires triton >= 3.4.0 and triton_kernels installed, we will default to dequantizing the model to bf16"
|
||||||
)
|
)
|
||||||
self.quantization_config.dequantize = True
|
self.quantization_config.dequantize = True
|
||||||
return
|
return
|
||||||
else:
|
elif not gpu_is_supported:
|
||||||
|
# we can't quantize the model in this case so we raise an error
|
||||||
|
raise ValueError(
|
||||||
|
"MXFP4 quantization is only supported on GPUs with compute capability >= 7.5 (e.g T4, A100, L4, H100, or B200)"
|
||||||
|
)
|
||||||
|
elif not kernels_available:
|
||||||
# we can't quantize the model in this case so we raise an error
|
# we can't quantize the model in this case so we raise an error
|
||||||
raise ValueError("MXFP4 quantization requires triton >= 3.4.0 and triton_kernels installed")
|
raise ValueError("MXFP4 quantization requires triton >= 3.4.0 and triton_kernels installed")
|
||||||
|
|
||||||
if major < 9:
|
|
||||||
raise ValueError(
|
|
||||||
"MXFP4 quantized models is only supported on GPUs with compute capability >= 9.0 (e.g H100, or B100)"
|
|
||||||
)
|
|
||||||
|
|
||||||
device_map = kwargs.get("device_map", None)
|
device_map = kwargs.get("device_map", None)
|
||||||
if device_map is None:
|
if device_map is None:
|
||||||
logger.warning_once(
|
logger.warning_once(
|
||||||
|
|||||||
@@ -107,18 +107,31 @@ class Mxfp4QuantizerTest(unittest.TestCase):
|
|||||||
|
|
||||||
def test_quantizer_validation_low_compute_capability(self):
|
def test_quantizer_validation_low_compute_capability(self):
|
||||||
"""Test quantizer validation with low compute capability"""
|
"""Test quantizer validation with low compute capability"""
|
||||||
with patch("torch.cuda.get_device_capability", return_value=(8, 0)):
|
with patch("torch.cuda.get_device_capability", return_value=(7, 0)):
|
||||||
|
from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer
|
||||||
|
|
||||||
|
config = Mxfp4Config()
|
||||||
|
quantizer = Mxfp4HfQuantizer(config)
|
||||||
|
quantizer.pre_quantized = False
|
||||||
|
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
quantizer.validate_environment()
|
||||||
|
|
||||||
|
def test_quantizer_validation_low_compute_capability_with_prequantized(self):
|
||||||
|
"""Test quantizer validation with low compute capability"""
|
||||||
|
with patch("torch.cuda.get_device_capability", return_value=(7, 0)):
|
||||||
from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer
|
from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer
|
||||||
|
|
||||||
config = Mxfp4Config()
|
config = Mxfp4Config()
|
||||||
quantizer = Mxfp4HfQuantizer(config)
|
quantizer = Mxfp4HfQuantizer(config)
|
||||||
|
|
||||||
with self.assertRaises(ValueError):
|
# Should automatically set dequantize=True and warn
|
||||||
quantizer.validate_environment()
|
quantizer.validate_environment()
|
||||||
|
self.assertTrue(quantizer.quantization_config.dequantize)
|
||||||
|
|
||||||
def test_quantizer_validation_low_compute_capability_with_dequantize(self):
|
def test_quantizer_validation_low_compute_capability_with_dequantize(self):
|
||||||
"""Test quantizer validation with low compute capability but dequantize enabled"""
|
"""Test quantizer validation with low compute capability but dequantize enabled"""
|
||||||
with patch("torch.cuda.get_device_capability", return_value=(8, 0)):
|
with patch("torch.cuda.get_device_capability", return_value=(7, 0)):
|
||||||
from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer
|
from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer
|
||||||
|
|
||||||
config = Mxfp4Config(dequantize=True)
|
config = Mxfp4Config(dequantize=True)
|
||||||
|
|||||||
Reference in New Issue
Block a user