Replace the triton_kernels dependency with kernels (#39926)

* remove dep

* style

* rm import

* fix

* style

* simplify

* style
This commit is contained in:
Marc Sun
2025-08-06 19:31:20 +02:00
committed by GitHub
parent cb2e0df2ec
commit 6902ffa505
6 changed files with 58 additions and 41 deletions

View File

@@ -18,11 +18,11 @@ from unittest.mock import patch
from transformers import AutoTokenizer, GptOssForCausalLM, Mxfp4Config
from transformers.testing_utils import (
require_kernels,
require_torch,
require_torch_gpu,
require_torch_large_gpu,
require_triton,
require_triton_kernels,
slow,
)
from transformers.utils import (
@@ -194,7 +194,7 @@ class Mxfp4QuantizerTest(unittest.TestCase):
"""Test quantizer validation when triton is not available"""
with (
patch("transformers.quantizers.quantizer_mxfp4.is_triton_available", return_value=False),
patch("transformers.quantizers.quantizer_mxfp4.is_triton_kernels_availalble", return_value=False),
patch("transformers.quantizers.quantizer_mxfp4.is_kernels_availalble", return_value=False),
):
from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer
@@ -208,7 +208,7 @@ class Mxfp4QuantizerTest(unittest.TestCase):
"""Test quantizer validation when triton is not available but model is pre-quantized and dequantize is False"""
with (
patch("transformers.quantizers.quantizer_mxfp4.is_triton_available", return_value=False),
patch("transformers.quantizers.quantizer_mxfp4.is_triton_kernels_availalble", return_value=False),
patch("transformers.quantizers.quantizer_mxfp4.is_kernels_availalble", return_value=False),
):
from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer
@@ -348,7 +348,7 @@ class Mxfp4IntegrationTest(unittest.TestCase):
self.assertEqual(result.dtype, torch.bfloat16)
@require_triton(min_version="3.4.0")
@require_triton_kernels
@require_kernels
@require_torch_gpu
@require_torch
def test_quantize_to_mxfp4(self):
@@ -368,12 +368,14 @@ class Mxfp4IntegrationTest(unittest.TestCase):
@require_torch
@require_torch_large_gpu
@require_triton(min_version="3.4.0")
@require_kernels
@slow
class Mxfp4ModelTest(unittest.TestCase):
"""Test mxfp4 with actual models (requires specific model and hardware)"""
# These should be paths to real OpenAI MoE models for proper testing
model_name_packed = "/fsx/mohamed/oai-hf/tests/20b_converted_packed" # TODO: Use real packed quantized model
model_name = "openai/gpt-oss-20b"
input_text = "Once upon a time"
@@ -421,12 +423,12 @@ class Mxfp4ModelTest(unittest.TestCase):
self.assertFalse(quantization_config.dequantize)
model = GptOssForCausalLM.from_pretrained(
self.model_name_packed,
self.model_name,
quantization_config=quantization_config,
torch_dtype=torch.bfloat16,
device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(self.model_name_packed)
tokenizer = AutoTokenizer.from_pretrained(self.model_name)
self.check_inference_correctness_quantized(model, tokenizer)
def test_gpt_oss_model_loading_dequantized_with_device_map(self):
@@ -438,12 +440,12 @@ class Mxfp4ModelTest(unittest.TestCase):
self.assertTrue(quantization_config.dequantize)
model = GptOssForCausalLM.from_pretrained(
self.model_name_packed,
self.model_name,
quantization_config=quantization_config,
torch_dtype=torch.bfloat16,
device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(self.model_name_packed)
tokenizer = AutoTokenizer.from_pretrained(self.model_name)
self.check_inference_correctness_quantized(model, tokenizer)
def test_model_device_map_validation(self):
@@ -464,12 +466,12 @@ class Mxfp4ModelTest(unittest.TestCase):
# Expected: quantized < dequantized < unquantized memory usage
quantization_config = Mxfp4Config(dequantize=True)
quantized_model = GptOssForCausalLM.from_pretrained(
self.model_name_packed,
self.model_name,
torch_dtype=torch.bfloat16,
device_map="auto",
)
dequantized_model = GptOssForCausalLM.from_pretrained(
self.model_name_packed,
self.model_name,
torch_dtype=torch.bfloat16,
device_map="auto",
quantization_config=quantization_config,