Replace the triton_kernels dependency with kernels (#39926)
* remove dep * style * rm import * fix * style * simplify * style
This commit is contained in:
@@ -18,11 +18,11 @@ from unittest.mock import patch
|
||||
|
||||
from transformers import AutoTokenizer, GptOssForCausalLM, Mxfp4Config
|
||||
from transformers.testing_utils import (
|
||||
require_kernels,
|
||||
require_torch,
|
||||
require_torch_gpu,
|
||||
require_torch_large_gpu,
|
||||
require_triton,
|
||||
require_triton_kernels,
|
||||
slow,
|
||||
)
|
||||
from transformers.utils import (
|
||||
@@ -194,7 +194,7 @@ class Mxfp4QuantizerTest(unittest.TestCase):
|
||||
"""Test quantizer validation when triton is not available"""
|
||||
with (
|
||||
patch("transformers.quantizers.quantizer_mxfp4.is_triton_available", return_value=False),
|
||||
patch("transformers.quantizers.quantizer_mxfp4.is_triton_kernels_availalble", return_value=False),
|
||||
patch("transformers.quantizers.quantizer_mxfp4.is_kernels_availalble", return_value=False),
|
||||
):
|
||||
from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer
|
||||
|
||||
@@ -208,7 +208,7 @@ class Mxfp4QuantizerTest(unittest.TestCase):
|
||||
"""Test quantizer validation when triton is not available but model is pre-quantized and dequantize is False"""
|
||||
with (
|
||||
patch("transformers.quantizers.quantizer_mxfp4.is_triton_available", return_value=False),
|
||||
patch("transformers.quantizers.quantizer_mxfp4.is_triton_kernels_availalble", return_value=False),
|
||||
patch("transformers.quantizers.quantizer_mxfp4.is_kernels_availalble", return_value=False),
|
||||
):
|
||||
from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer
|
||||
|
||||
@@ -348,7 +348,7 @@ class Mxfp4IntegrationTest(unittest.TestCase):
|
||||
self.assertEqual(result.dtype, torch.bfloat16)
|
||||
|
||||
@require_triton(min_version="3.4.0")
|
||||
@require_triton_kernels
|
||||
@require_kernels
|
||||
@require_torch_gpu
|
||||
@require_torch
|
||||
def test_quantize_to_mxfp4(self):
|
||||
@@ -368,12 +368,14 @@ class Mxfp4IntegrationTest(unittest.TestCase):
|
||||
|
||||
@require_torch
|
||||
@require_torch_large_gpu
|
||||
@require_triton(min_version="3.4.0")
|
||||
@require_kernels
|
||||
@slow
|
||||
class Mxfp4ModelTest(unittest.TestCase):
|
||||
"""Test mxfp4 with actual models (requires specific model and hardware)"""
|
||||
|
||||
# These should be paths to real OpenAI MoE models for proper testing
|
||||
model_name_packed = "/fsx/mohamed/oai-hf/tests/20b_converted_packed" # TODO: Use real packed quantized model
|
||||
model_name = "openai/gpt-oss-20b"
|
||||
|
||||
input_text = "Once upon a time"
|
||||
|
||||
@@ -421,12 +423,12 @@ class Mxfp4ModelTest(unittest.TestCase):
|
||||
self.assertFalse(quantization_config.dequantize)
|
||||
|
||||
model = GptOssForCausalLM.from_pretrained(
|
||||
self.model_name_packed,
|
||||
self.model_name,
|
||||
quantization_config=quantization_config,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="auto",
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(self.model_name_packed)
|
||||
tokenizer = AutoTokenizer.from_pretrained(self.model_name)
|
||||
self.check_inference_correctness_quantized(model, tokenizer)
|
||||
|
||||
def test_gpt_oss_model_loading_dequantized_with_device_map(self):
|
||||
@@ -438,12 +440,12 @@ class Mxfp4ModelTest(unittest.TestCase):
|
||||
self.assertTrue(quantization_config.dequantize)
|
||||
|
||||
model = GptOssForCausalLM.from_pretrained(
|
||||
self.model_name_packed,
|
||||
self.model_name,
|
||||
quantization_config=quantization_config,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="auto",
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(self.model_name_packed)
|
||||
tokenizer = AutoTokenizer.from_pretrained(self.model_name)
|
||||
self.check_inference_correctness_quantized(model, tokenizer)
|
||||
|
||||
def test_model_device_map_validation(self):
|
||||
@@ -464,12 +466,12 @@ class Mxfp4ModelTest(unittest.TestCase):
|
||||
# Expected: quantized < dequantized < unquantized memory usage
|
||||
quantization_config = Mxfp4Config(dequantize=True)
|
||||
quantized_model = GptOssForCausalLM.from_pretrained(
|
||||
self.model_name_packed,
|
||||
self.model_name,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="auto",
|
||||
)
|
||||
dequantized_model = GptOssForCausalLM.from_pretrained(
|
||||
self.model_name_packed,
|
||||
self.model_name,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="auto",
|
||||
quantization_config=quantization_config,
|
||||
|
||||
Reference in New Issue
Block a user