remove triton_kernels dep with kernels instead (#39926)
* remove dep * style * rm import * fix * style * simplify * style
This commit is contained in:
@@ -49,7 +49,7 @@ FP4_VALUES = [
|
||||
|
||||
# Copied from GPT_OSS repo and vllm
|
||||
def quantize_to_mxfp4(w):
|
||||
from triton_kernels.numerics_details.mxfp import downcast_to_mxfp
|
||||
downcast_to_mxfp = triton_kernels_hub.numerics_details.mxfp.downcast_to_mxfp
|
||||
|
||||
w, w_scale = downcast_to_mxfp(w.to(torch.bfloat16), torch.uint8, axis=1)
|
||||
w, w_scale = swizzle_mxfp4(w, w_scale)
|
||||
@@ -57,9 +57,13 @@ def quantize_to_mxfp4(w):
|
||||
|
||||
|
||||
def swizzle_mxfp4(w, w_scale):
|
||||
from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor
|
||||
from triton_kernels.tensor_details import layout
|
||||
from triton_kernels.tensor_details.layout import StridedLayout
|
||||
FP4, convert_layout, wrap_torch_tensor = (
|
||||
triton_kernels_hub.tensor.FP4,
|
||||
triton_kernels_hub.tensor.convert_layout,
|
||||
triton_kernels_hub.tensor.wrap_torch_tensor,
|
||||
)
|
||||
layout = triton_kernels_hub.tensor_details.layout
|
||||
StridedLayout = triton_kernels_hub.tensor_details.layout.StridedLayout
|
||||
|
||||
value_layout, value_layout_opts = layout.make_default_matmul_mxfp4_w_layout(mx_axis=1)
|
||||
w = convert_layout(wrap_torch_tensor(w, dtype=FP4), value_layout, **value_layout_opts)
|
||||
@@ -173,8 +177,12 @@ class Mxfp4GptOssExperts(nn.Module):
|
||||
self.down_proj_precision_config = None
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, routing_data, gather_idx, scatter_idx) -> torch.Tensor:
|
||||
from triton_kernels.matmul_ogs import FnSpecs, FusedActivation, matmul_ogs
|
||||
from triton_kernels.swiglu import swiglu_fn
|
||||
FnSpecs, FusedActivation, matmul_ogs = (
|
||||
triton_kernels_hub.matmul_ogs.FnSpecs,
|
||||
triton_kernels_hub.matmul_ogs.FusedActivation,
|
||||
triton_kernels_hub.matmul_ogs.matmul_ogs,
|
||||
)
|
||||
swiglu_fn = triton_kernels_hub.swiglu.swiglu_fn
|
||||
|
||||
with torch.cuda.device(hidden_states.device):
|
||||
act = FusedActivation(FnSpecs("swiglu", swiglu_fn, ("alpha", "limit")), (self.alpha, None), 2)
|
||||
@@ -211,7 +219,12 @@ def routing_torch_dist(
|
||||
):
|
||||
import os
|
||||
|
||||
from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx, compute_expt_data_torch
|
||||
GatherIndx, RoutingData, ScatterIndx, compute_expt_data_torch = (
|
||||
triton_kernels_hub.routing.GatherIndx,
|
||||
triton_kernels_hub.routing.RoutingData,
|
||||
triton_kernels_hub.routing.ScatterIndx,
|
||||
triton_kernels_hub.routing.compute_expt_data_torch,
|
||||
)
|
||||
|
||||
with torch.cuda.device(logits.device):
|
||||
world_size = torch.distributed.get_world_size()
|
||||
@@ -274,7 +287,7 @@ def mlp_forward(self, hidden_states):
|
||||
if dist.is_available() and dist.is_initialized():
|
||||
routing = routing_torch_dist
|
||||
else:
|
||||
from triton_kernels.routing import routing
|
||||
routing = triton_kernels_hub.routing.routing
|
||||
|
||||
routing = routing
|
||||
batch_size = hidden_states.shape[0]
|
||||
@@ -337,8 +350,11 @@ def dequantize(module, param_name, param_value, target_device, dq_param_name, **
|
||||
|
||||
|
||||
def load_and_swizzle_mxfp4(module, param_name, param_value, target_device, **kwargs):
|
||||
from triton_kernels.matmul_ogs import FlexCtx, InFlexData, PrecisionConfig
|
||||
|
||||
PrecisionConfig, FlexCtx, InFlexData = (
|
||||
triton_kernels_hub.matmul_ogs.PrecisionConfig,
|
||||
triton_kernels_hub.matmul_ogs.FlexCtx,
|
||||
triton_kernels_hub.matmul_ogs.InFlexData,
|
||||
)
|
||||
from ..integrations.tensor_parallel import shard_and_distribute_module
|
||||
|
||||
model = kwargs.get("model", None)
|
||||
@@ -450,6 +466,11 @@ def replace_with_mxfp4_linear(
|
||||
):
|
||||
if quantization_config.dequantize:
|
||||
return model
|
||||
else:
|
||||
from kernels import get_kernel
|
||||
|
||||
global triton_kernels_hub
|
||||
triton_kernels_hub = get_kernel("kernels-community/triton_kernels")
|
||||
|
||||
modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert
|
||||
|
||||
|
||||
@@ -21,9 +21,9 @@ if TYPE_CHECKING:
|
||||
|
||||
from ..utils import (
|
||||
is_accelerate_available,
|
||||
is_kernels_available,
|
||||
is_torch_available,
|
||||
is_triton_available,
|
||||
is_triton_kernels_availalble,
|
||||
logging,
|
||||
)
|
||||
from .quantizers_utils import get_module_from_name
|
||||
@@ -68,7 +68,7 @@ class Mxfp4HfQuantizer(HfQuantizer):
|
||||
|
||||
compute_capability = torch.cuda.get_device_capability()
|
||||
gpu_is_supported = compute_capability >= (7, 5)
|
||||
kernels_available = is_triton_available("3.4.0") and is_triton_kernels_availalble()
|
||||
kernels_available = is_triton_available("3.4.0") and is_kernels_available()
|
||||
|
||||
if self.pre_quantized:
|
||||
# On unsupported GPUs or without kernels, we will dequantize the model to bf16
|
||||
@@ -82,7 +82,7 @@ class Mxfp4HfQuantizer(HfQuantizer):
|
||||
|
||||
if not kernels_available:
|
||||
logger.warning_once(
|
||||
"MXFP4 quantization requires triton >= 3.4.0 and triton_kernels installed, we will default to dequantizing the model to bf16"
|
||||
"MXFP4 quantization requires triton >= 3.4.0 and kernels installed, we will default to dequantizing the model to bf16"
|
||||
)
|
||||
self.quantization_config.dequantize = True
|
||||
return
|
||||
@@ -95,6 +95,12 @@ class Mxfp4HfQuantizer(HfQuantizer):
|
||||
# we can't quantize the model in this case so we raise an error
|
||||
raise ValueError("MXFP4 quantization requires triton >= 3.4.0 and triton_kernels installed")
|
||||
|
||||
if not self.pre_quantized:
|
||||
from kernels import get_kernel
|
||||
|
||||
global triton_kernels_hub
|
||||
triton_kernels_hub = get_kernel("kernels-community/triton_kernels")
|
||||
|
||||
device_map = kwargs.get("device_map", None)
|
||||
if device_map is None:
|
||||
logger.warning_once(
|
||||
@@ -160,13 +166,15 @@ class Mxfp4HfQuantizer(HfQuantizer):
|
||||
unexpected_keys: Optional[list[str]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
if is_triton_kernels_availalble() and is_triton_available("3.4.0"):
|
||||
from triton_kernels.matmul_ogs import FlexCtx, InFlexData, PrecisionConfig
|
||||
|
||||
from ..integrations import Mxfp4GptOssExperts, dequantize, load_and_swizzle_mxfp4, quantize_to_mxfp4
|
||||
from ..models.gpt_oss.modeling_gpt_oss import GptOssExperts
|
||||
|
||||
if not self.pre_quantized:
|
||||
PrecisionConfig, FlexCtx, InFlexData = (
|
||||
triton_kernels_hub.matmul_ogs.PrecisionConfig,
|
||||
triton_kernels_hub.matmul_ogs.FlexCtx,
|
||||
triton_kernels_hub.matmul_ogs.InFlexData,
|
||||
)
|
||||
module, _ = get_module_from_name(model, param_name)
|
||||
with torch.cuda.device(target_device):
|
||||
if isinstance(module, Mxfp4GptOssExperts):
|
||||
|
||||
@@ -170,7 +170,6 @@ from .utils import (
|
||||
is_torchdynamo_available,
|
||||
is_torchvision_available,
|
||||
is_triton_available,
|
||||
is_triton_kernels_availalble,
|
||||
is_vision_available,
|
||||
is_vptq_available,
|
||||
strtobool,
|
||||
@@ -471,13 +470,6 @@ def require_triton(min_version: str = TRITON_MIN_VERSION):
|
||||
return decorator
|
||||
|
||||
|
||||
def require_triton_kernels(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires triton_kernels. These tests are skipped when triton_kernels isn't installed.
|
||||
"""
|
||||
return unittest.skipUnless(is_triton_kernels_availalble(), "test requires triton_kernels")(test_case)
|
||||
|
||||
|
||||
def require_gguf(test_case, min_version: str = GGUF_MIN_VERSION):
|
||||
"""
|
||||
Decorator marking a test that requires ggguf. These tests are skipped when gguf isn't installed.
|
||||
|
||||
@@ -270,7 +270,6 @@ from .import_utils import (
|
||||
is_torchvision_v2_available,
|
||||
is_training_run_on_sagemaker,
|
||||
is_triton_available,
|
||||
is_triton_kernels_availalble,
|
||||
is_uroman_available,
|
||||
is_vision_available,
|
||||
is_vptq_available,
|
||||
|
||||
@@ -238,7 +238,6 @@ _kernels_available = _is_package_available("kernels")
|
||||
_matplotlib_available = _is_package_available("matplotlib")
|
||||
_mistral_common_available = _is_package_available("mistral_common")
|
||||
_triton_available, _triton_version = _is_package_available("triton", return_version=True)
|
||||
_triton_kernels_available = _is_package_available("triton_kernels")
|
||||
|
||||
_torch_version = "N/A"
|
||||
_torch_available = False
|
||||
@@ -423,10 +422,6 @@ def is_triton_available(min_version: str = TRITON_MIN_VERSION):
|
||||
return _triton_available and version.parse(_triton_version) >= version.parse(min_version)
|
||||
|
||||
|
||||
def is_triton_kernels_availalble():
|
||||
return _triton_kernels_available
|
||||
|
||||
|
||||
def is_hadamard_available():
|
||||
return _hadamard_available
|
||||
|
||||
|
||||
@@ -18,11 +18,11 @@ from unittest.mock import patch
|
||||
|
||||
from transformers import AutoTokenizer, GptOssForCausalLM, Mxfp4Config
|
||||
from transformers.testing_utils import (
|
||||
require_kernels,
|
||||
require_torch,
|
||||
require_torch_gpu,
|
||||
require_torch_large_gpu,
|
||||
require_triton,
|
||||
require_triton_kernels,
|
||||
slow,
|
||||
)
|
||||
from transformers.utils import (
|
||||
@@ -194,7 +194,7 @@ class Mxfp4QuantizerTest(unittest.TestCase):
|
||||
"""Test quantizer validation when triton is not available"""
|
||||
with (
|
||||
patch("transformers.quantizers.quantizer_mxfp4.is_triton_available", return_value=False),
|
||||
patch("transformers.quantizers.quantizer_mxfp4.is_triton_kernels_availalble", return_value=False),
|
||||
patch("transformers.quantizers.quantizer_mxfp4.is_kernels_availalble", return_value=False),
|
||||
):
|
||||
from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer
|
||||
|
||||
@@ -208,7 +208,7 @@ class Mxfp4QuantizerTest(unittest.TestCase):
|
||||
"""Test quantizer validation when triton is not available but model is pre-quantized and dequantize is False"""
|
||||
with (
|
||||
patch("transformers.quantizers.quantizer_mxfp4.is_triton_available", return_value=False),
|
||||
patch("transformers.quantizers.quantizer_mxfp4.is_triton_kernels_availalble", return_value=False),
|
||||
patch("transformers.quantizers.quantizer_mxfp4.is_kernels_availalble", return_value=False),
|
||||
):
|
||||
from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer
|
||||
|
||||
@@ -348,7 +348,7 @@ class Mxfp4IntegrationTest(unittest.TestCase):
|
||||
self.assertEqual(result.dtype, torch.bfloat16)
|
||||
|
||||
@require_triton(min_version="3.4.0")
|
||||
@require_triton_kernels
|
||||
@require_kernels
|
||||
@require_torch_gpu
|
||||
@require_torch
|
||||
def test_quantize_to_mxfp4(self):
|
||||
@@ -368,12 +368,14 @@ class Mxfp4IntegrationTest(unittest.TestCase):
|
||||
|
||||
@require_torch
|
||||
@require_torch_large_gpu
|
||||
@require_triton(min_version="3.4.0")
|
||||
@require_kernels
|
||||
@slow
|
||||
class Mxfp4ModelTest(unittest.TestCase):
|
||||
"""Test mxfp4 with actual models (requires specific model and hardware)"""
|
||||
|
||||
# These should be paths to real OpenAI MoE models for proper testing
|
||||
model_name_packed = "/fsx/mohamed/oai-hf/tests/20b_converted_packed" # TODO: Use real packed quantized model
|
||||
model_name = "openai/gpt-oss-20b"
|
||||
|
||||
input_text = "Once upon a time"
|
||||
|
||||
@@ -421,12 +423,12 @@ class Mxfp4ModelTest(unittest.TestCase):
|
||||
self.assertFalse(quantization_config.dequantize)
|
||||
|
||||
model = GptOssForCausalLM.from_pretrained(
|
||||
self.model_name_packed,
|
||||
self.model_name,
|
||||
quantization_config=quantization_config,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="auto",
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(self.model_name_packed)
|
||||
tokenizer = AutoTokenizer.from_pretrained(self.model_name)
|
||||
self.check_inference_correctness_quantized(model, tokenizer)
|
||||
|
||||
def test_gpt_oss_model_loading_dequantized_with_device_map(self):
|
||||
@@ -438,12 +440,12 @@ class Mxfp4ModelTest(unittest.TestCase):
|
||||
self.assertTrue(quantization_config.dequantize)
|
||||
|
||||
model = GptOssForCausalLM.from_pretrained(
|
||||
self.model_name_packed,
|
||||
self.model_name,
|
||||
quantization_config=quantization_config,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="auto",
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(self.model_name_packed)
|
||||
tokenizer = AutoTokenizer.from_pretrained(self.model_name)
|
||||
self.check_inference_correctness_quantized(model, tokenizer)
|
||||
|
||||
def test_model_device_map_validation(self):
|
||||
@@ -464,12 +466,12 @@ class Mxfp4ModelTest(unittest.TestCase):
|
||||
# Expected: quantized < dequantized < unquantized memory usage
|
||||
quantization_config = Mxfp4Config(dequantize=True)
|
||||
quantized_model = GptOssForCausalLM.from_pretrained(
|
||||
self.model_name_packed,
|
||||
self.model_name,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="auto",
|
||||
)
|
||||
dequantized_model = GptOssForCausalLM.from_pretrained(
|
||||
self.model_name_packed,
|
||||
self.model_name,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="auto",
|
||||
quantization_config=quantization_config,
|
||||
|
||||
Reference in New Issue
Block a user