From 6902ffa505970a39bd90f69b1afaa7affd564ac5 Mon Sep 17 00:00:00 2001 From: Marc Sun <57196510+SunMarc@users.noreply.github.com> Date: Wed, 6 Aug 2025 19:31:20 +0200 Subject: [PATCH] remove `triton_kernels` dep with `kernels` instead (#39926) * remove dep * style * rm import * fix * style * simplify * style --- src/transformers/integrations/mxfp4.py | 41 ++++++++++++++----- .../quantizers/quantizer_mxfp4.py | 20 ++++++--- src/transformers/testing_utils.py | 8 ---- src/transformers/utils/__init__.py | 1 - src/transformers/utils/import_utils.py | 5 --- tests/quantization/mxfp4/test_mxfp4.py | 24 ++++++----- 6 files changed, 58 insertions(+), 41 deletions(-) diff --git a/src/transformers/integrations/mxfp4.py b/src/transformers/integrations/mxfp4.py index 393894c243..b37f72e7c3 100644 --- a/src/transformers/integrations/mxfp4.py +++ b/src/transformers/integrations/mxfp4.py @@ -49,7 +49,7 @@ FP4_VALUES = [ # Copied from GPT_OSS repo and vllm def quantize_to_mxfp4(w): - from triton_kernels.numerics_details.mxfp import downcast_to_mxfp + downcast_to_mxfp = triton_kernels_hub.numerics_details.mxfp.downcast_to_mxfp w, w_scale = downcast_to_mxfp(w.to(torch.bfloat16), torch.uint8, axis=1) w, w_scale = swizzle_mxfp4(w, w_scale) @@ -57,9 +57,13 @@ def quantize_to_mxfp4(w): def swizzle_mxfp4(w, w_scale): - from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor - from triton_kernels.tensor_details import layout - from triton_kernels.tensor_details.layout import StridedLayout + FP4, convert_layout, wrap_torch_tensor = ( + triton_kernels_hub.tensor.FP4, + triton_kernels_hub.tensor.convert_layout, + triton_kernels_hub.tensor.wrap_torch_tensor, + ) + layout = triton_kernels_hub.tensor_details.layout + StridedLayout = triton_kernels_hub.tensor_details.layout.StridedLayout value_layout, value_layout_opts = layout.make_default_matmul_mxfp4_w_layout(mx_axis=1) w = convert_layout(wrap_torch_tensor(w, dtype=FP4), value_layout, **value_layout_opts) @@ -173,8 +177,12 @@ class Mxfp4GptOssExperts(nn.Module): self.down_proj_precision_config = None def forward(self, hidden_states: torch.Tensor, routing_data, gather_idx, scatter_idx) -> torch.Tensor: - from triton_kernels.matmul_ogs import FnSpecs, FusedActivation, matmul_ogs - from triton_kernels.swiglu import swiglu_fn + FnSpecs, FusedActivation, matmul_ogs = ( + triton_kernels_hub.matmul_ogs.FnSpecs, + triton_kernels_hub.matmul_ogs.FusedActivation, + triton_kernels_hub.matmul_ogs.matmul_ogs, + ) + swiglu_fn = triton_kernels_hub.swiglu.swiglu_fn with torch.cuda.device(hidden_states.device): act = FusedActivation(FnSpecs("swiglu", swiglu_fn, ("alpha", "limit")), (self.alpha, None), 2) @@ -211,7 +219,12 @@ def routing_torch_dist( ): import os - from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx, compute_expt_data_torch + GatherIndx, RoutingData, ScatterIndx, compute_expt_data_torch = ( + triton_kernels_hub.routing.GatherIndx, + triton_kernels_hub.routing.RoutingData, + triton_kernels_hub.routing.ScatterIndx, + triton_kernels_hub.routing.compute_expt_data_torch, + ) with torch.cuda.device(logits.device): world_size = torch.distributed.get_world_size() @@ -274,7 +287,7 @@ def mlp_forward(self, hidden_states): if dist.is_available() and dist.is_initialized(): routing = routing_torch_dist else: - from triton_kernels.routing import routing + routing = triton_kernels_hub.routing.routing routing = routing batch_size = hidden_states.shape[0] @@ -337,8 +350,11 @@ def dequantize(module, param_name, param_value, target_device, dq_param_name, ** def load_and_swizzle_mxfp4(module, param_name, param_value, target_device, **kwargs): - from triton_kernels.matmul_ogs import FlexCtx, InFlexData, PrecisionConfig - + PrecisionConfig, FlexCtx, InFlexData = ( + triton_kernels_hub.matmul_ogs.PrecisionConfig, + triton_kernels_hub.matmul_ogs.FlexCtx, + triton_kernels_hub.matmul_ogs.InFlexData, + ) from ..integrations.tensor_parallel import shard_and_distribute_module model = kwargs.get("model", None) @@ -450,6 +466,11 @@ def replace_with_mxfp4_linear( ): if quantization_config.dequantize: return model + else: + from kernels import get_kernel + + global triton_kernels_hub + triton_kernels_hub = get_kernel("kernels-community/triton_kernels") modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert diff --git a/src/transformers/quantizers/quantizer_mxfp4.py b/src/transformers/quantizers/quantizer_mxfp4.py index f68a00e897..5281d4d763 100644 --- a/src/transformers/quantizers/quantizer_mxfp4.py +++ b/src/transformers/quantizers/quantizer_mxfp4.py @@ -21,9 +21,9 @@ if TYPE_CHECKING: from ..utils import ( is_accelerate_available, + is_kernels_available, is_torch_available, is_triton_available, - is_triton_kernels_availalble, logging, ) from .quantizers_utils import get_module_from_name @@ -68,7 +68,7 @@ class Mxfp4HfQuantizer(HfQuantizer): compute_capability = torch.cuda.get_device_capability() gpu_is_supported = compute_capability >= (7, 5) - kernels_available = is_triton_available("3.4.0") and is_triton_kernels_availalble() + kernels_available = is_triton_available("3.4.0") and is_kernels_available() if self.pre_quantized: # On unsupported GPUs or without kernels, we will dequantize the model to bf16 @@ -82,7 +82,7 @@ class Mxfp4HfQuantizer(HfQuantizer): if not kernels_available: logger.warning_once( - "MXFP4 quantization requires triton >= 3.4.0 and triton_kernels installed, we will default to dequantizing the model to bf16" + "MXFP4 quantization requires triton >= 3.4.0 and kernels installed, we will default to dequantizing the model to bf16" ) self.quantization_config.dequantize = True return @@ -95,6 +95,12 @@ class Mxfp4HfQuantizer(HfQuantizer): # we can't quantize the model in this case so we raise an error raise ValueError("MXFP4 quantization requires triton >= 3.4.0 and triton_kernels installed") + if not self.pre_quantized: + from kernels import get_kernel + + global triton_kernels_hub + triton_kernels_hub = get_kernel("kernels-community/triton_kernels") + device_map = kwargs.get("device_map", None) if device_map is None: logger.warning_once( @@ -160,13 +166,15 @@ class Mxfp4HfQuantizer(HfQuantizer): unexpected_keys: Optional[list[str]] = None, **kwargs, ): - if is_triton_kernels_availalble() and is_triton_available("3.4.0"): - from triton_kernels.matmul_ogs import FlexCtx, InFlexData, PrecisionConfig - from ..integrations import Mxfp4GptOssExperts, dequantize, load_and_swizzle_mxfp4, quantize_to_mxfp4 from ..models.gpt_oss.modeling_gpt_oss import GptOssExperts if not self.pre_quantized: + PrecisionConfig, FlexCtx, InFlexData = ( + triton_kernels_hub.matmul_ogs.PrecisionConfig, + triton_kernels_hub.matmul_ogs.FlexCtx, + triton_kernels_hub.matmul_ogs.InFlexData, + ) module, _ = get_module_from_name(model, param_name) with torch.cuda.device(target_device): if isinstance(module, Mxfp4GptOssExperts): diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index 7eb8b3c46e..e75bbab5b0 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -170,7 +170,6 @@ from .utils import ( is_torchdynamo_available, is_torchvision_available, is_triton_available, - is_triton_kernels_availalble, is_vision_available, is_vptq_available, strtobool, @@ -471,13 +470,6 @@ def require_triton(min_version: str = TRITON_MIN_VERSION): return decorator -def require_triton_kernels(test_case): - """ - Decorator marking a test that requires triton_kernels. These tests are skipped when triton_kernels isn't installed. - """ - return unittest.skipUnless(is_triton_kernels_availalble(), "test requires triton_kernels")(test_case) - - def require_gguf(test_case, min_version: str = GGUF_MIN_VERSION): """ Decorator marking a test that requires ggguf. These tests are skipped when gguf isn't installed. diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py index f53914baa5..c28ae9a5b1 100644 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -270,7 +270,6 @@ from .import_utils import ( is_torchvision_v2_available, is_training_run_on_sagemaker, is_triton_available, - is_triton_kernels_availalble, is_uroman_available, is_vision_available, is_vptq_available, diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index a858c94c69..da740e68de 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -238,7 +238,6 @@ _kernels_available = _is_package_available("kernels") _matplotlib_available = _is_package_available("matplotlib") _mistral_common_available = _is_package_available("mistral_common") _triton_available, _triton_version = _is_package_available("triton", return_version=True) -_triton_kernels_available = _is_package_available("triton_kernels") _torch_version = "N/A" _torch_available = False @@ -423,10 +422,6 @@ def is_triton_available(min_version: str = TRITON_MIN_VERSION): return _triton_available and version.parse(_triton_version) >= version.parse(min_version) -def is_triton_kernels_availalble(): - return _triton_kernels_available - - def is_hadamard_available(): return _hadamard_available diff --git a/tests/quantization/mxfp4/test_mxfp4.py b/tests/quantization/mxfp4/test_mxfp4.py index ca14d86a34..1743891f8b 100644 --- a/tests/quantization/mxfp4/test_mxfp4.py +++ b/tests/quantization/mxfp4/test_mxfp4.py @@ -18,11 +18,11 @@ from unittest.mock import patch from transformers import AutoTokenizer, GptOssForCausalLM, Mxfp4Config from transformers.testing_utils import ( + require_kernels, require_torch, require_torch_gpu, require_torch_large_gpu, require_triton, - require_triton_kernels, slow, ) from transformers.utils import ( @@ -194,7 +194,7 @@ class Mxfp4QuantizerTest(unittest.TestCase): """Test quantizer validation when triton is not available""" with ( patch("transformers.quantizers.quantizer_mxfp4.is_triton_available", return_value=False), - patch("transformers.quantizers.quantizer_mxfp4.is_triton_kernels_availalble", return_value=False), + patch("transformers.quantizers.quantizer_mxfp4.is_kernels_availalble", return_value=False), ): from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer @@ -208,7 +208,7 @@ class Mxfp4QuantizerTest(unittest.TestCase): """Test quantizer validation when triton is not available but model is pre-quantized and dequantize is False""" with ( patch("transformers.quantizers.quantizer_mxfp4.is_triton_available", return_value=False), - patch("transformers.quantizers.quantizer_mxfp4.is_triton_kernels_availalble", return_value=False), + patch("transformers.quantizers.quantizer_mxfp4.is_kernels_availalble", return_value=False), ): from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer @@ -348,7 +348,7 @@ class Mxfp4IntegrationTest(unittest.TestCase): self.assertEqual(result.dtype, torch.bfloat16) @require_triton(min_version="3.4.0") - @require_triton_kernels + @require_kernels @require_torch_gpu @require_torch def test_quantize_to_mxfp4(self): @@ -368,12 +368,14 @@ class Mxfp4IntegrationTest(unittest.TestCase): @require_torch @require_torch_large_gpu +@require_triton(min_version="3.4.0") +@require_kernels @slow class Mxfp4ModelTest(unittest.TestCase): """Test mxfp4 with actual models (requires specific model and hardware)""" # These should be paths to real OpenAI MoE models for proper testing - model_name_packed = "/fsx/mohamed/oai-hf/tests/20b_converted_packed" # TODO: Use real packed quantized model + model_name = "openai/gpt-oss-20b" input_text = "Once upon a time" @@ -421,12 +423,12 @@ class Mxfp4ModelTest(unittest.TestCase): self.assertFalse(quantization_config.dequantize) model = GptOssForCausalLM.from_pretrained( - self.model_name_packed, + self.model_name, quantization_config=quantization_config, torch_dtype=torch.bfloat16, device_map="auto", ) - tokenizer = AutoTokenizer.from_pretrained(self.model_name_packed) + tokenizer = AutoTokenizer.from_pretrained(self.model_name) self.check_inference_correctness_quantized(model, tokenizer) def test_gpt_oss_model_loading_dequantized_with_device_map(self): @@ -438,12 +440,12 @@ class Mxfp4ModelTest(unittest.TestCase): self.assertTrue(quantization_config.dequantize) model = GptOssForCausalLM.from_pretrained( - self.model_name_packed, + self.model_name, quantization_config=quantization_config, torch_dtype=torch.bfloat16, device_map="auto", ) - tokenizer = AutoTokenizer.from_pretrained(self.model_name_packed) + tokenizer = AutoTokenizer.from_pretrained(self.model_name) self.check_inference_correctness_quantized(model, tokenizer) def test_model_device_map_validation(self): @@ -464,12 +466,12 @@ class Mxfp4ModelTest(unittest.TestCase): # Expected: quantized < dequantized < unquantized memory usage quantization_config = Mxfp4Config(dequantize=True) quantized_model = GptOssForCausalLM.from_pretrained( - self.model_name_packed, + self.model_name, torch_dtype=torch.bfloat16, device_map="auto", ) dequantized_model = GptOssForCausalLM.from_pretrained( - self.model_name_packed, + self.model_name, torch_dtype=torch.bfloat16, device_map="auto", quantization_config=quantization_config,