remove triton_kernels dep with kernels instead (#39926)

* remove dep

* style

* rm import

* fix

* style

* simplify

* style
This commit is contained in:
Marc Sun
2025-08-06 19:31:20 +02:00
committed by GitHub
parent cb2e0df2ec
commit 6902ffa505
6 changed files with 58 additions and 41 deletions

View File

@@ -49,7 +49,7 @@ FP4_VALUES = [
# Copied from GPT_OSS repo and vllm
def quantize_to_mxfp4(w):
from triton_kernels.numerics_details.mxfp import downcast_to_mxfp
downcast_to_mxfp = triton_kernels_hub.numerics_details.mxfp.downcast_to_mxfp
w, w_scale = downcast_to_mxfp(w.to(torch.bfloat16), torch.uint8, axis=1)
w, w_scale = swizzle_mxfp4(w, w_scale)
@@ -57,9 +57,13 @@ def quantize_to_mxfp4(w):
def swizzle_mxfp4(w, w_scale):
from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor
from triton_kernels.tensor_details import layout
from triton_kernels.tensor_details.layout import StridedLayout
FP4, convert_layout, wrap_torch_tensor = (
triton_kernels_hub.tensor.FP4,
triton_kernels_hub.tensor.convert_layout,
triton_kernels_hub.tensor.wrap_torch_tensor,
)
layout = triton_kernels_hub.tensor_details.layout
StridedLayout = triton_kernels_hub.tensor_details.layout.StridedLayout
value_layout, value_layout_opts = layout.make_default_matmul_mxfp4_w_layout(mx_axis=1)
w = convert_layout(wrap_torch_tensor(w, dtype=FP4), value_layout, **value_layout_opts)
@@ -173,8 +177,12 @@ class Mxfp4GptOssExperts(nn.Module):
self.down_proj_precision_config = None
def forward(self, hidden_states: torch.Tensor, routing_data, gather_idx, scatter_idx) -> torch.Tensor:
from triton_kernels.matmul_ogs import FnSpecs, FusedActivation, matmul_ogs
from triton_kernels.swiglu import swiglu_fn
FnSpecs, FusedActivation, matmul_ogs = (
triton_kernels_hub.matmul_ogs.FnSpecs,
triton_kernels_hub.matmul_ogs.FusedActivation,
triton_kernels_hub.matmul_ogs.matmul_ogs,
)
swiglu_fn = triton_kernels_hub.swiglu.swiglu_fn
with torch.cuda.device(hidden_states.device):
act = FusedActivation(FnSpecs("swiglu", swiglu_fn, ("alpha", "limit")), (self.alpha, None), 2)
@@ -211,7 +219,12 @@ def routing_torch_dist(
):
import os
from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx, compute_expt_data_torch
GatherIndx, RoutingData, ScatterIndx, compute_expt_data_torch = (
triton_kernels_hub.routing.GatherIndx,
triton_kernels_hub.routing.RoutingData,
triton_kernels_hub.routing.ScatterIndx,
triton_kernels_hub.routing.compute_expt_data_torch,
)
with torch.cuda.device(logits.device):
world_size = torch.distributed.get_world_size()
@@ -274,7 +287,7 @@ def mlp_forward(self, hidden_states):
if dist.is_available() and dist.is_initialized():
routing = routing_torch_dist
else:
from triton_kernels.routing import routing
routing = triton_kernels_hub.routing.routing
routing = routing
batch_size = hidden_states.shape[0]
@@ -337,8 +350,11 @@ def dequantize(module, param_name, param_value, target_device, dq_param_name, **
def load_and_swizzle_mxfp4(module, param_name, param_value, target_device, **kwargs):
from triton_kernels.matmul_ogs import FlexCtx, InFlexData, PrecisionConfig
PrecisionConfig, FlexCtx, InFlexData = (
triton_kernels_hub.matmul_ogs.PrecisionConfig,
triton_kernels_hub.matmul_ogs.FlexCtx,
triton_kernels_hub.matmul_ogs.InFlexData,
)
from ..integrations.tensor_parallel import shard_and_distribute_module
model = kwargs.get("model", None)
@@ -450,6 +466,11 @@ def replace_with_mxfp4_linear(
):
if quantization_config.dequantize:
return model
else:
from kernels import get_kernel
global triton_kernels_hub
triton_kernels_hub = get_kernel("kernels-community/triton_kernels")
modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert

View File

@@ -21,9 +21,9 @@ if TYPE_CHECKING:
from ..utils import (
is_accelerate_available,
is_kernels_available,
is_torch_available,
is_triton_available,
is_triton_kernels_availalble,
logging,
)
from .quantizers_utils import get_module_from_name
@@ -68,7 +68,7 @@ class Mxfp4HfQuantizer(HfQuantizer):
compute_capability = torch.cuda.get_device_capability()
gpu_is_supported = compute_capability >= (7, 5)
kernels_available = is_triton_available("3.4.0") and is_triton_kernels_availalble()
kernels_available = is_triton_available("3.4.0") and is_kernels_available()
if self.pre_quantized:
# On unsupported GPUs or without kernels, we will dequantize the model to bf16
@@ -82,7 +82,7 @@ class Mxfp4HfQuantizer(HfQuantizer):
if not kernels_available:
logger.warning_once(
"MXFP4 quantization requires triton >= 3.4.0 and triton_kernels installed, we will default to dequantizing the model to bf16"
"MXFP4 quantization requires triton >= 3.4.0 and kernels installed, we will default to dequantizing the model to bf16"
)
self.quantization_config.dequantize = True
return
@@ -95,6 +95,12 @@ class Mxfp4HfQuantizer(HfQuantizer):
# we can't quantize the model in this case so we raise an error
raise ValueError("MXFP4 quantization requires triton >= 3.4.0 and triton_kernels installed")
if not self.pre_quantized:
from kernels import get_kernel
global triton_kernels_hub
triton_kernels_hub = get_kernel("kernels-community/triton_kernels")
device_map = kwargs.get("device_map", None)
if device_map is None:
logger.warning_once(
@@ -160,13 +166,15 @@ class Mxfp4HfQuantizer(HfQuantizer):
unexpected_keys: Optional[list[str]] = None,
**kwargs,
):
if is_triton_kernels_availalble() and is_triton_available("3.4.0"):
from triton_kernels.matmul_ogs import FlexCtx, InFlexData, PrecisionConfig
from ..integrations import Mxfp4GptOssExperts, dequantize, load_and_swizzle_mxfp4, quantize_to_mxfp4
from ..models.gpt_oss.modeling_gpt_oss import GptOssExperts
if not self.pre_quantized:
PrecisionConfig, FlexCtx, InFlexData = (
triton_kernels_hub.matmul_ogs.PrecisionConfig,
triton_kernels_hub.matmul_ogs.FlexCtx,
triton_kernels_hub.matmul_ogs.InFlexData,
)
module, _ = get_module_from_name(model, param_name)
with torch.cuda.device(target_device):
if isinstance(module, Mxfp4GptOssExperts):

View File

@@ -170,7 +170,6 @@ from .utils import (
is_torchdynamo_available,
is_torchvision_available,
is_triton_available,
is_triton_kernels_availalble,
is_vision_available,
is_vptq_available,
strtobool,
@@ -471,13 +470,6 @@ def require_triton(min_version: str = TRITON_MIN_VERSION):
return decorator
def require_triton_kernels(test_case):
"""
Decorator marking a test that requires triton_kernels. These tests are skipped when triton_kernels isn't installed.
"""
return unittest.skipUnless(is_triton_kernels_availalble(), "test requires triton_kernels")(test_case)
def require_gguf(test_case, min_version: str = GGUF_MIN_VERSION):
"""
Decorator marking a test that requires ggguf. These tests are skipped when gguf isn't installed.

View File

@@ -270,7 +270,6 @@ from .import_utils import (
is_torchvision_v2_available,
is_training_run_on_sagemaker,
is_triton_available,
is_triton_kernels_availalble,
is_uroman_available,
is_vision_available,
is_vptq_available,

View File

@@ -238,7 +238,6 @@ _kernels_available = _is_package_available("kernels")
_matplotlib_available = _is_package_available("matplotlib")
_mistral_common_available = _is_package_available("mistral_common")
_triton_available, _triton_version = _is_package_available("triton", return_version=True)
_triton_kernels_available = _is_package_available("triton_kernels")
_torch_version = "N/A"
_torch_available = False
@@ -423,10 +422,6 @@ def is_triton_available(min_version: str = TRITON_MIN_VERSION):
return _triton_available and version.parse(_triton_version) >= version.parse(min_version)
def is_triton_kernels_availalble():
return _triton_kernels_available
def is_hadamard_available():
return _hadamard_available

View File

@@ -18,11 +18,11 @@ from unittest.mock import patch
from transformers import AutoTokenizer, GptOssForCausalLM, Mxfp4Config
from transformers.testing_utils import (
require_kernels,
require_torch,
require_torch_gpu,
require_torch_large_gpu,
require_triton,
require_triton_kernels,
slow,
)
from transformers.utils import (
@@ -194,7 +194,7 @@ class Mxfp4QuantizerTest(unittest.TestCase):
"""Test quantizer validation when triton is not available"""
with (
patch("transformers.quantizers.quantizer_mxfp4.is_triton_available", return_value=False),
patch("transformers.quantizers.quantizer_mxfp4.is_triton_kernels_availalble", return_value=False),
patch("transformers.quantizers.quantizer_mxfp4.is_kernels_availalble", return_value=False),
):
from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer
@@ -208,7 +208,7 @@ class Mxfp4QuantizerTest(unittest.TestCase):
"""Test quantizer validation when triton is not available but model is pre-quantized and dequantize is False"""
with (
patch("transformers.quantizers.quantizer_mxfp4.is_triton_available", return_value=False),
patch("transformers.quantizers.quantizer_mxfp4.is_triton_kernels_availalble", return_value=False),
patch("transformers.quantizers.quantizer_mxfp4.is_kernels_availalble", return_value=False),
):
from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer
@@ -348,7 +348,7 @@ class Mxfp4IntegrationTest(unittest.TestCase):
self.assertEqual(result.dtype, torch.bfloat16)
@require_triton(min_version="3.4.0")
@require_triton_kernels
@require_kernels
@require_torch_gpu
@require_torch
def test_quantize_to_mxfp4(self):
@@ -368,12 +368,14 @@ class Mxfp4IntegrationTest(unittest.TestCase):
@require_torch
@require_torch_large_gpu
@require_triton(min_version="3.4.0")
@require_kernels
@slow
class Mxfp4ModelTest(unittest.TestCase):
"""Test mxfp4 with actual models (requires specific model and hardware)"""
# These should be paths to real OpenAI MoE models for proper testing
model_name_packed = "/fsx/mohamed/oai-hf/tests/20b_converted_packed" # TODO: Use real packed quantized model
model_name = "openai/gpt-oss-20b"
input_text = "Once upon a time"
@@ -421,12 +423,12 @@ class Mxfp4ModelTest(unittest.TestCase):
self.assertFalse(quantization_config.dequantize)
model = GptOssForCausalLM.from_pretrained(
self.model_name_packed,
self.model_name,
quantization_config=quantization_config,
torch_dtype=torch.bfloat16,
device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(self.model_name_packed)
tokenizer = AutoTokenizer.from_pretrained(self.model_name)
self.check_inference_correctness_quantized(model, tokenizer)
def test_gpt_oss_model_loading_dequantized_with_device_map(self):
@@ -438,12 +440,12 @@ class Mxfp4ModelTest(unittest.TestCase):
self.assertTrue(quantization_config.dequantize)
model = GptOssForCausalLM.from_pretrained(
self.model_name_packed,
self.model_name,
quantization_config=quantization_config,
torch_dtype=torch.bfloat16,
device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(self.model_name_packed)
tokenizer = AutoTokenizer.from_pretrained(self.model_name)
self.check_inference_correctness_quantized(model, tokenizer)
def test_model_device_map_validation(self):
@@ -464,12 +466,12 @@ class Mxfp4ModelTest(unittest.TestCase):
# Expected: quantized < dequantized < unquantized memory usage
quantization_config = Mxfp4Config(dequantize=True)
quantized_model = GptOssForCausalLM.from_pretrained(
self.model_name_packed,
self.model_name,
torch_dtype=torch.bfloat16,
device_map="auto",
)
dequantized_model = GptOssForCausalLM.from_pretrained(
self.model_name_packed,
self.model_name,
torch_dtype=torch.bfloat16,
device_map="auto",
quantization_config=quantization_config,