Replace the triton_kernels dependency with kernels (#39926)

* remove dep

* style

* rm import

* fix

* style

* simplify

* style
This commit is contained in:
Marc Sun
2025-08-06 19:31:20 +02:00
committed by GitHub
parent cb2e0df2ec
commit 6902ffa505
6 changed files with 58 additions and 41 deletions

View File

@@ -49,7 +49,7 @@ FP4_VALUES = [
# Copied from GPT_OSS repo and vllm # Copied from GPT_OSS repo and vllm
def quantize_to_mxfp4(w): def quantize_to_mxfp4(w):
from triton_kernels.numerics_details.mxfp import downcast_to_mxfp downcast_to_mxfp = triton_kernels_hub.numerics_details.mxfp.downcast_to_mxfp
w, w_scale = downcast_to_mxfp(w.to(torch.bfloat16), torch.uint8, axis=1) w, w_scale = downcast_to_mxfp(w.to(torch.bfloat16), torch.uint8, axis=1)
w, w_scale = swizzle_mxfp4(w, w_scale) w, w_scale = swizzle_mxfp4(w, w_scale)
@@ -57,9 +57,13 @@ def quantize_to_mxfp4(w):
def swizzle_mxfp4(w, w_scale): def swizzle_mxfp4(w, w_scale):
from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor FP4, convert_layout, wrap_torch_tensor = (
from triton_kernels.tensor_details import layout triton_kernels_hub.tensor.FP4,
from triton_kernels.tensor_details.layout import StridedLayout triton_kernels_hub.tensor.convert_layout,
triton_kernels_hub.tensor.wrap_torch_tensor,
)
layout = triton_kernels_hub.tensor_details.layout
StridedLayout = triton_kernels_hub.tensor_details.layout.StridedLayout
value_layout, value_layout_opts = layout.make_default_matmul_mxfp4_w_layout(mx_axis=1) value_layout, value_layout_opts = layout.make_default_matmul_mxfp4_w_layout(mx_axis=1)
w = convert_layout(wrap_torch_tensor(w, dtype=FP4), value_layout, **value_layout_opts) w = convert_layout(wrap_torch_tensor(w, dtype=FP4), value_layout, **value_layout_opts)
@@ -173,8 +177,12 @@ class Mxfp4GptOssExperts(nn.Module):
self.down_proj_precision_config = None self.down_proj_precision_config = None
def forward(self, hidden_states: torch.Tensor, routing_data, gather_idx, scatter_idx) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor, routing_data, gather_idx, scatter_idx) -> torch.Tensor:
from triton_kernels.matmul_ogs import FnSpecs, FusedActivation, matmul_ogs FnSpecs, FusedActivation, matmul_ogs = (
from triton_kernels.swiglu import swiglu_fn triton_kernels_hub.matmul_ogs.FnSpecs,
triton_kernels_hub.matmul_ogs.FusedActivation,
triton_kernels_hub.matmul_ogs.matmul_ogs,
)
swiglu_fn = triton_kernels_hub.swiglu.swiglu_fn
with torch.cuda.device(hidden_states.device): with torch.cuda.device(hidden_states.device):
act = FusedActivation(FnSpecs("swiglu", swiglu_fn, ("alpha", "limit")), (self.alpha, None), 2) act = FusedActivation(FnSpecs("swiglu", swiglu_fn, ("alpha", "limit")), (self.alpha, None), 2)
@@ -211,7 +219,12 @@ def routing_torch_dist(
): ):
import os import os
from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx, compute_expt_data_torch GatherIndx, RoutingData, ScatterIndx, compute_expt_data_torch = (
triton_kernels_hub.routing.GatherIndx,
triton_kernels_hub.routing.RoutingData,
triton_kernels_hub.routing.ScatterIndx,
triton_kernels_hub.routing.compute_expt_data_torch,
)
with torch.cuda.device(logits.device): with torch.cuda.device(logits.device):
world_size = torch.distributed.get_world_size() world_size = torch.distributed.get_world_size()
@@ -274,7 +287,7 @@ def mlp_forward(self, hidden_states):
if dist.is_available() and dist.is_initialized(): if dist.is_available() and dist.is_initialized():
routing = routing_torch_dist routing = routing_torch_dist
else: else:
from triton_kernels.routing import routing routing = triton_kernels_hub.routing.routing
routing = routing routing = routing
batch_size = hidden_states.shape[0] batch_size = hidden_states.shape[0]
@@ -337,8 +350,11 @@ def dequantize(module, param_name, param_value, target_device, dq_param_name, **
def load_and_swizzle_mxfp4(module, param_name, param_value, target_device, **kwargs): def load_and_swizzle_mxfp4(module, param_name, param_value, target_device, **kwargs):
from triton_kernels.matmul_ogs import FlexCtx, InFlexData, PrecisionConfig PrecisionConfig, FlexCtx, InFlexData = (
triton_kernels_hub.matmul_ogs.PrecisionConfig,
triton_kernels_hub.matmul_ogs.FlexCtx,
triton_kernels_hub.matmul_ogs.InFlexData,
)
from ..integrations.tensor_parallel import shard_and_distribute_module from ..integrations.tensor_parallel import shard_and_distribute_module
model = kwargs.get("model", None) model = kwargs.get("model", None)
@@ -450,6 +466,11 @@ def replace_with_mxfp4_linear(
): ):
if quantization_config.dequantize: if quantization_config.dequantize:
return model return model
else:
from kernels import get_kernel
global triton_kernels_hub
triton_kernels_hub = get_kernel("kernels-community/triton_kernels")
modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert

View File

@@ -21,9 +21,9 @@ if TYPE_CHECKING:
from ..utils import ( from ..utils import (
is_accelerate_available, is_accelerate_available,
is_kernels_available,
is_torch_available, is_torch_available,
is_triton_available, is_triton_available,
is_triton_kernels_availalble,
logging, logging,
) )
from .quantizers_utils import get_module_from_name from .quantizers_utils import get_module_from_name
@@ -68,7 +68,7 @@ class Mxfp4HfQuantizer(HfQuantizer):
compute_capability = torch.cuda.get_device_capability() compute_capability = torch.cuda.get_device_capability()
gpu_is_supported = compute_capability >= (7, 5) gpu_is_supported = compute_capability >= (7, 5)
kernels_available = is_triton_available("3.4.0") and is_triton_kernels_availalble() kernels_available = is_triton_available("3.4.0") and is_kernels_available()
if self.pre_quantized: if self.pre_quantized:
# On unsupported GPUs or without kernels, we will dequantize the model to bf16 # On unsupported GPUs or without kernels, we will dequantize the model to bf16
@@ -82,7 +82,7 @@ class Mxfp4HfQuantizer(HfQuantizer):
if not kernels_available: if not kernels_available:
logger.warning_once( logger.warning_once(
"MXFP4 quantization requires triton >= 3.4.0 and triton_kernels installed, we will default to dequantizing the model to bf16" "MXFP4 quantization requires triton >= 3.4.0 and kernels installed, we will default to dequantizing the model to bf16"
) )
self.quantization_config.dequantize = True self.quantization_config.dequantize = True
return return
@@ -95,6 +95,12 @@ class Mxfp4HfQuantizer(HfQuantizer):
# we can't quantize the model in this case so we raise an error # we can't quantize the model in this case so we raise an error
raise ValueError("MXFP4 quantization requires triton >= 3.4.0 and triton_kernels installed") raise ValueError("MXFP4 quantization requires triton >= 3.4.0 and triton_kernels installed")
if not self.pre_quantized:
from kernels import get_kernel
global triton_kernels_hub
triton_kernels_hub = get_kernel("kernels-community/triton_kernels")
device_map = kwargs.get("device_map", None) device_map = kwargs.get("device_map", None)
if device_map is None: if device_map is None:
logger.warning_once( logger.warning_once(
@@ -160,13 +166,15 @@ class Mxfp4HfQuantizer(HfQuantizer):
unexpected_keys: Optional[list[str]] = None, unexpected_keys: Optional[list[str]] = None,
**kwargs, **kwargs,
): ):
if is_triton_kernels_availalble() and is_triton_available("3.4.0"):
from triton_kernels.matmul_ogs import FlexCtx, InFlexData, PrecisionConfig
from ..integrations import Mxfp4GptOssExperts, dequantize, load_and_swizzle_mxfp4, quantize_to_mxfp4 from ..integrations import Mxfp4GptOssExperts, dequantize, load_and_swizzle_mxfp4, quantize_to_mxfp4
from ..models.gpt_oss.modeling_gpt_oss import GptOssExperts from ..models.gpt_oss.modeling_gpt_oss import GptOssExperts
if not self.pre_quantized: if not self.pre_quantized:
PrecisionConfig, FlexCtx, InFlexData = (
triton_kernels_hub.matmul_ogs.PrecisionConfig,
triton_kernels_hub.matmul_ogs.FlexCtx,
triton_kernels_hub.matmul_ogs.InFlexData,
)
module, _ = get_module_from_name(model, param_name) module, _ = get_module_from_name(model, param_name)
with torch.cuda.device(target_device): with torch.cuda.device(target_device):
if isinstance(module, Mxfp4GptOssExperts): if isinstance(module, Mxfp4GptOssExperts):

View File

@@ -170,7 +170,6 @@ from .utils import (
is_torchdynamo_available, is_torchdynamo_available,
is_torchvision_available, is_torchvision_available,
is_triton_available, is_triton_available,
is_triton_kernels_availalble,
is_vision_available, is_vision_available,
is_vptq_available, is_vptq_available,
strtobool, strtobool,
@@ -471,13 +470,6 @@ def require_triton(min_version: str = TRITON_MIN_VERSION):
return decorator return decorator
def require_triton_kernels(test_case):
"""
Decorator marking a test that requires triton_kernels. These tests are skipped when triton_kernels isn't installed.
"""
return unittest.skipUnless(is_triton_kernels_availalble(), "test requires triton_kernels")(test_case)
def require_gguf(test_case, min_version: str = GGUF_MIN_VERSION): def require_gguf(test_case, min_version: str = GGUF_MIN_VERSION):
""" """
Decorator marking a test that requires gguf. These tests are skipped when gguf isn't installed. Decorator marking a test that requires gguf. These tests are skipped when gguf isn't installed.

View File

@@ -270,7 +270,6 @@ from .import_utils import (
is_torchvision_v2_available, is_torchvision_v2_available,
is_training_run_on_sagemaker, is_training_run_on_sagemaker,
is_triton_available, is_triton_available,
is_triton_kernels_availalble,
is_uroman_available, is_uroman_available,
is_vision_available, is_vision_available,
is_vptq_available, is_vptq_available,

View File

@@ -238,7 +238,6 @@ _kernels_available = _is_package_available("kernels")
_matplotlib_available = _is_package_available("matplotlib") _matplotlib_available = _is_package_available("matplotlib")
_mistral_common_available = _is_package_available("mistral_common") _mistral_common_available = _is_package_available("mistral_common")
_triton_available, _triton_version = _is_package_available("triton", return_version=True) _triton_available, _triton_version = _is_package_available("triton", return_version=True)
_triton_kernels_available = _is_package_available("triton_kernels")
_torch_version = "N/A" _torch_version = "N/A"
_torch_available = False _torch_available = False
@@ -423,10 +422,6 @@ def is_triton_available(min_version: str = TRITON_MIN_VERSION):
return _triton_available and version.parse(_triton_version) >= version.parse(min_version) return _triton_available and version.parse(_triton_version) >= version.parse(min_version)
def is_triton_kernels_availalble():
return _triton_kernels_available
def is_hadamard_available(): def is_hadamard_available():
return _hadamard_available return _hadamard_available

View File

@@ -18,11 +18,11 @@ from unittest.mock import patch
from transformers import AutoTokenizer, GptOssForCausalLM, Mxfp4Config from transformers import AutoTokenizer, GptOssForCausalLM, Mxfp4Config
from transformers.testing_utils import ( from transformers.testing_utils import (
require_kernels,
require_torch, require_torch,
require_torch_gpu, require_torch_gpu,
require_torch_large_gpu, require_torch_large_gpu,
require_triton, require_triton,
require_triton_kernels,
slow, slow,
) )
from transformers.utils import ( from transformers.utils import (
@@ -194,7 +194,7 @@ class Mxfp4QuantizerTest(unittest.TestCase):
"""Test quantizer validation when triton is not available""" """Test quantizer validation when triton is not available"""
with ( with (
patch("transformers.quantizers.quantizer_mxfp4.is_triton_available", return_value=False), patch("transformers.quantizers.quantizer_mxfp4.is_triton_available", return_value=False),
patch("transformers.quantizers.quantizer_mxfp4.is_triton_kernels_availalble", return_value=False), patch("transformers.quantizers.quantizer_mxfp4.is_kernels_availalble", return_value=False),
): ):
from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer
@@ -208,7 +208,7 @@ class Mxfp4QuantizerTest(unittest.TestCase):
"""Test quantizer validation when triton is not available but model is pre-quantized and dequantize is False""" """Test quantizer validation when triton is not available but model is pre-quantized and dequantize is False"""
with ( with (
patch("transformers.quantizers.quantizer_mxfp4.is_triton_available", return_value=False), patch("transformers.quantizers.quantizer_mxfp4.is_triton_available", return_value=False),
patch("transformers.quantizers.quantizer_mxfp4.is_triton_kernels_availalble", return_value=False), patch("transformers.quantizers.quantizer_mxfp4.is_kernels_availalble", return_value=False),
): ):
from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer
@@ -348,7 +348,7 @@ class Mxfp4IntegrationTest(unittest.TestCase):
self.assertEqual(result.dtype, torch.bfloat16) self.assertEqual(result.dtype, torch.bfloat16)
@require_triton(min_version="3.4.0") @require_triton(min_version="3.4.0")
@require_triton_kernels @require_kernels
@require_torch_gpu @require_torch_gpu
@require_torch @require_torch
def test_quantize_to_mxfp4(self): def test_quantize_to_mxfp4(self):
@@ -368,12 +368,14 @@ class Mxfp4IntegrationTest(unittest.TestCase):
@require_torch @require_torch
@require_torch_large_gpu @require_torch_large_gpu
@require_triton(min_version="3.4.0")
@require_kernels
@slow @slow
class Mxfp4ModelTest(unittest.TestCase): class Mxfp4ModelTest(unittest.TestCase):
"""Test mxfp4 with actual models (requires specific model and hardware)""" """Test mxfp4 with actual models (requires specific model and hardware)"""
# These should be paths to real OpenAI MoE models for proper testing # These should be paths to real OpenAI MoE models for proper testing
model_name_packed = "/fsx/mohamed/oai-hf/tests/20b_converted_packed" # TODO: Use real packed quantized model model_name = "openai/gpt-oss-20b"
input_text = "Once upon a time" input_text = "Once upon a time"
@@ -421,12 +423,12 @@ class Mxfp4ModelTest(unittest.TestCase):
self.assertFalse(quantization_config.dequantize) self.assertFalse(quantization_config.dequantize)
model = GptOssForCausalLM.from_pretrained( model = GptOssForCausalLM.from_pretrained(
self.model_name_packed, self.model_name,
quantization_config=quantization_config, quantization_config=quantization_config,
torch_dtype=torch.bfloat16, torch_dtype=torch.bfloat16,
device_map="auto", device_map="auto",
) )
tokenizer = AutoTokenizer.from_pretrained(self.model_name_packed) tokenizer = AutoTokenizer.from_pretrained(self.model_name)
self.check_inference_correctness_quantized(model, tokenizer) self.check_inference_correctness_quantized(model, tokenizer)
def test_gpt_oss_model_loading_dequantized_with_device_map(self): def test_gpt_oss_model_loading_dequantized_with_device_map(self):
@@ -438,12 +440,12 @@ class Mxfp4ModelTest(unittest.TestCase):
self.assertTrue(quantization_config.dequantize) self.assertTrue(quantization_config.dequantize)
model = GptOssForCausalLM.from_pretrained( model = GptOssForCausalLM.from_pretrained(
self.model_name_packed, self.model_name,
quantization_config=quantization_config, quantization_config=quantization_config,
torch_dtype=torch.bfloat16, torch_dtype=torch.bfloat16,
device_map="auto", device_map="auto",
) )
tokenizer = AutoTokenizer.from_pretrained(self.model_name_packed) tokenizer = AutoTokenizer.from_pretrained(self.model_name)
self.check_inference_correctness_quantized(model, tokenizer) self.check_inference_correctness_quantized(model, tokenizer)
def test_model_device_map_validation(self): def test_model_device_map_validation(self):
@@ -464,12 +466,12 @@ class Mxfp4ModelTest(unittest.TestCase):
# Expected: quantized < dequantized < unquantized memory usage # Expected: quantized < dequantized < unquantized memory usage
quantization_config = Mxfp4Config(dequantize=True) quantization_config = Mxfp4Config(dequantize=True)
quantized_model = GptOssForCausalLM.from_pretrained( quantized_model = GptOssForCausalLM.from_pretrained(
self.model_name_packed, self.model_name,
torch_dtype=torch.bfloat16, torch_dtype=torch.bfloat16,
device_map="auto", device_map="auto",
) )
dequantized_model = GptOssForCausalLM.from_pretrained( dequantized_model = GptOssForCausalLM.from_pretrained(
self.model_name_packed, self.model_name,
torch_dtype=torch.bfloat16, torch_dtype=torch.bfloat16,
device_map="auto", device_map="auto",
quantization_config=quantization_config, quantization_config=quantization_config,