FP-Quant support (#38696)

* quartet

* quartet qat -> quartet

* format

* bf16 backward

* interfaces

* forward_method

* quartet -> fp_quant

* style

* List -> list

* list typing

* fixed format and annotations

* test_fp_quant

* docstrings and default dtypes

* better docstring and removed noop checks

* docs

* pseudoquantization support to test on non-blackwell

* pseudoquant

* Pseudoquant docs

* Update docs/source/en/quantization/fp_quant.md

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* Update docs/source/en/quantization/fp_quant.md

* Update docs/source/en/quantization/fp_quant.md

* Update src/transformers/utils/quantization_config.py

Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>

* Update tests/quantization/fp_quant_integration/test_fp_quant.py

Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>

* Update tests/quantization/fp_quant_integration/test_fp_quant.py

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* small test fixes

* dockerfile update

* spec link

* removed `_process_model_after_weight_loading`

* toctree

---------

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>
This commit is contained in:
Andrei Panferov
2025-07-23 11:41:10 +02:00
committed by GitHub
parent eb1a007f7f
commit 623ab01039
15 changed files with 629 additions and 0 deletions


@@ -78,6 +78,9 @@ RUN git clone https://github.com/NetEase-FuXi/EETQ.git && cd EETQ/ && git submod
# RUN python3 -m pip install --no-cache-dir flute-kernel==0.4.1
# RUN python3 -m pip install --no-cache-dir git+https://github.com/Dao-AILab/fast-hadamard-transform.git

# Add fp-quant for quantization testing
RUN python3 -m pip install --no-cache-dir "fp-quant>=0.1.6"

# Add compressed-tensors for quantization testing
RUN python3 -m pip install --no-cache-dir compressed-tensors


@@ -179,6 +179,8 @@
  title: FBGEMM
- local: quantization/finegrained_fp8
  title: Fine-grained FP8
- local: quantization/fp_quant
  title: FP-Quant
- local: gguf
  title: GGUF
- local: quantization/gptq


@@ -93,6 +93,10 @@ Learn how to quantize models in the [Quantization](../quantization) guide.
[[autodoc]] QuarkConfig

## FPQuantConfig

[[autodoc]] FPQuantConfig

## AutoRoundConfig

[[autodoc]] AutoRoundConfig


@@ -0,0 +1,66 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
# FP-Quant
[FP-Quant](https://github.com/IST-DASLab/FP-Quant) is a family of quantization algorithms tailored for the Blackwell generation of Nvidia GPUs. The goal is to allow for efficient post-training quantization (PTQ) and quantization-aware training (QAT) of LLMs in the [MXFP4 and NVFP4 data-types](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf).

Currently, only PTQ with MXFP4 is supported. Models can either be quantized on the fly with `quantization_config=FPQuantConfig()`:
```python
from transformers import AutoModelForCausalLM, AutoTokenizer, FPQuantConfig
import torch

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3-8B",
    quantization_config=FPQuantConfig(),
    device_map="cuda",
    torch_dtype=torch.bfloat16,
)
```
or pre-processed with GPTQ for better quality (see [FP Format Quantization Harness](https://github.com/IST-DASLab/FP-Quant)).
A **Blackwell-generation GPU is required** to run the kernels. Runtime support for FP-Quant is implemented through the [QuTLASS](https://github.com/IST-DASLab/qutlass) library and a lightweight PyTorch interface library, [`fp_quant`](https://github.com/IST-DASLab/FP-Quant/tree/master/inference_lib). We recommend installing the former **from source** and the latter with `pip install fp_quant`.

Users **without a Blackwell-generation GPU** can use the method with `quantization_config=FPQuantConfig(pseudoquantization=True)` without having to install [QuTLASS](https://github.com/IST-DASLab/qutlass). This provides no speedup but fully emulates the effect of quantization.
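
To give a feel for what "emulating the effect of quantization" means, here is a minimal pure-Python sketch of MXFP4 pseudo-quantization: every group of 32 values shares one power-of-two scale, each value is rounded to the nearest FP4 (E2M1) magnitude, and the result is converted straight back to a float. The scale rule follows the conversion recipe in the OCP Microscaling specification linked above. The function names are made up for this illustration, and the sketch is not the `fp_quant` implementation: it leaves out the Hadamard transform, the `forward_method` variants, and round-to-nearest-even tie-breaking.

```python
import math

# Magnitudes representable by the FP4 E2M1 element format used in MXFP4.
FP4_GRID = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)
GROUP_SIZE = 32  # number of elements sharing one power-of-two scale
E2M1_EMAX = 2  # exponent of the largest power of two in E2M1 (4.0 == 2**2)


def pseudoquantize_group(values):
    """Round one group to MXFP4 and return the dequantized floats."""
    amax = max(abs(v) for v in values)
    if amax == 0.0:
        return [0.0] * len(values)
    # Shared scale: a power of two derived from the group's largest magnitude.
    scale = 2.0 ** (math.floor(math.log2(amax)) - E2M1_EMAX)
    result = []
    for v in values:
        magnitude = min(abs(v) / scale, FP4_GRID[-1])  # saturate at 6.0
        # Ties go to the smaller grid point in this sketch.
        nearest = min(FP4_GRID, key=lambda g: abs(g - magnitude))
        result.append(math.copysign(nearest * scale, v))
    return result


def pseudoquantize(values, group_size=GROUP_SIZE):
    """Apply pseudoquantize_group to consecutive groups of a flat list."""
    if len(values) % group_size != 0:
        raise ValueError("length must be a multiple of the group size")
    out = []
    for start in range(0, len(values), group_size):
        out.extend(pseudoquantize_group(values[start : start + group_size]))
    return out


row = [1.0, -0.3, 0.0, 0.6] + [0.0] * 28
print(pseudoquantize(row)[:4])  # [1.0, -0.25, 0.0, 0.5]
```

The output keeps the original dtype and shape, which is why pseudo-quantization gives no speedup: only the values change, not the storage format or the matmul kernel.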
> [!TIP]
> Find models pre-quantized with FP-Quant in the official ISTA-DASLab [collection](https://huggingface.co/collections/ISTA-DASLab/fp-quant-6877c186103a21d3a02568ee).
## torch.compile
FP-Quant is fully compatible with [torch.compile](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html).
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, FPQuantConfig

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3-8B",
    quantization_config=FPQuantConfig(),
    device_map="cuda",
    torch_dtype=torch.bfloat16,
)

model.forward = torch.compile(model.forward, mode="max-autotune", fullgraph=True)
```
## Speedups
FP-Quant currently performs best when processing very large batches.
See [QuTLASS README](https://github.com/IST-DASLab/qutlass/blob/main/README.md) for speedups.


@@ -30,6 +30,7 @@ Use the Space below to help you pick a quantization method depending on your har
| [bitsandbytes](./bitsandbytes) | 🟢 | 🟡 | 🟢 | 🟡 | 🔴 | 🟡 | 🟢 | 4/8 | 🟢 | 🟢 | 🟢 | https://github.com/bitsandbytes-foundation/bitsandbytes |
| [compressed-tensors](./compressed_tensors) | 🔴 | 🟢 | 🟢 | 🟢 | 🔴 | 🔴 | 🔴 | 1/8 | 🟢 | 🟢 | 🟢 | https://github.com/neuralmagic/compressed-tensors |
| [EETQ](./eetq) | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | ? | 8 | 🟢 | 🟢 | 🟢 | https://github.com/NetEase-FuXi/EETQ |
| [FP-Quant](./fp_quant) | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | 🟢 | 4 | 🔴 | 🟢 | 🟢 | https://github.com/IST-DASLab/FP-Quant |
| [GGUF / GGML (llama.cpp)](../gguf) | 🟢 | 🟢 | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 | 1/8 | 🔴 | [See Notes](../gguf) | [See Notes](../gguf) | https://github.com/ggerganov/llama.cpp |
| [GPTQModel](./gptq) | 🔴 | 🟢 | 🟢 | 🟢 | 🟢 | 🟢 | 🔴 | 2/3/4/8 | 🟢 | 🟢 | 🟢 | https://github.com/ModelCloud/GPTQModel |
| [AutoGPTQ](./gptq) | 🔴 | 🔴 | 🟢 | 🟢 | 🔴 | 🔴 | 🔴 | 2/3/4/8 | 🟢 | 🟢 | 🟢 | https://github.com/AutoGPTQ/AutoGPTQ |


@@ -275,6 +275,7 @@ _import_structure = {
        "HqqConfig",
        "QuantoConfig",
        "QuarkConfig",
        "FPQuantConfig",
        "SpQRConfig",
        "TorchAoConfig",
        "VptqConfig",
@@ -961,6 +962,7 @@ if TYPE_CHECKING:
        EetqConfig,
        FbgemmFp8Config,
        FineGrainedFP8Config,
        FPQuantConfig,
        GPTQConfig,
        HiggsConfig,
        HqqConfig,


@@ -0,0 +1,47 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"FP-Quant integration file"

from ..utils import (
    is_fp_quant_available,
)


if is_fp_quant_available():
    from fp_quant import FPQuantConfig as FPQuantLinearConfig
    from fp_quant import FPQuantDtype

from transformers.utils.quantization_config import FPQuantConfig


def adapt_fp_quant_config(config: FPQuantConfig):
    if config.forward_dtype == "mxfp4":
        forward_dtype = FPQuantDtype.MXFP4
    else:
        raise ValueError(f"Unsupported forward dtype: {config.forward_dtype}")

    if config.backward_dtype == "bf16":
        backward_dtype = FPQuantDtype.BF16
    else:
        raise ValueError(f"Unsupported backward dtype: {config.backward_dtype}")

    return FPQuantLinearConfig(
        forward_dtype=forward_dtype,
        forward_method=config.forward_method,
        backward_dtype=backward_dtype,
        store_master_weights=config.store_master_weights,
        hadamard_group_size=config.hadamard_group_size,
        pseudoquantization=config.pseudoquantization,
        modules_to_not_convert=config.modules_to_not_convert,
    )
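
Reviewer note: the string-to-enum translation in `adapt_fp_quant_config` could also be written as a table lookup, which may scale better once more dtypes are supported. The sketch below is only an illustration of that pattern. `SketchDtype` is a hypothetical stand-in for `fp_quant.FPQuantDtype` (so that the snippet runs without the package), and `resolve_dtype` and the two tables are names invented here, not part of the commit.

```python
from enum import Enum


class SketchDtype(Enum):
    """Hypothetical stand-in for fp_quant.FPQuantDtype (only the members used in the diff)."""

    MXFP4 = "mxfp4"
    BF16 = "bf16"


# One table per role, mirroring the two if/else chains in the diff.
FORWARD_DTYPES = {"mxfp4": SketchDtype.MXFP4}
BACKWARD_DTYPES = {"bf16": SketchDtype.BF16}


def resolve_dtype(name, table, role):
    """Translate a config string into an enum member or raise ValueError."""
    try:
        return table[name]
    except KeyError:
        raise ValueError(f"Unsupported {role} dtype: {name}") from None


print(resolve_dtype("mxfp4", FORWARD_DTYPES, "forward"))  # SketchDtype.MXFP4
```

The error messages match the ones raised in the diff, so callers would see the same behavior for unsupported values.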


@@ -27,6 +27,7 @@ from ..utils.quantization_config import (
    EetqConfig,
    FbgemmFp8Config,
    FineGrainedFP8Config,
    FPQuantConfig,
    GPTQConfig,
    HiggsConfig,
    HqqConfig,
@@ -49,6 +50,7 @@ from .quantizer_compressed_tensors import CompressedTensorsHfQuantizer
from .quantizer_eetq import EetqHfQuantizer
from .quantizer_fbgemm_fp8 import FbgemmFp8HfQuantizer
from .quantizer_finegrained_fp8 import FineGrainedFP8HfQuantizer
from .quantizer_fp_quant import FPQuantHfQuantizer
from .quantizer_gptq import GptqHfQuantizer
from .quantizer_higgs import HiggsHfQuantizer
from .quantizer_hqq import HqqHfQuantizer
@@ -67,6 +69,7 @@ AUTO_QUANTIZER_MAPPING = {
    "aqlm": AqlmHfQuantizer,
    "quanto": QuantoHfQuantizer,
    "quark": QuarkHfQuantizer,
    "fp_quant": FPQuantHfQuantizer,
    "eetq": EetqHfQuantizer,
    "higgs": HiggsHfQuantizer,
    "hqq": HqqHfQuantizer,
@@ -89,6 +92,7 @@ AUTO_QUANTIZATION_CONFIG_MAPPING = {
    "aqlm": AqlmConfig,
    "quanto": QuantoConfig,
    "quark": QuarkConfig,
    "fp_quant": FPQuantConfig,
    "hqq": HqqConfig,
    "compressed-tensors": CompressedTensorsConfig,
    "fbgemm_fp8": FbgemmFp8Config,


@@ -0,0 +1,183 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Any, Optional

from .base import HfQuantizer
from .quantizers_utils import get_module_from_name


if TYPE_CHECKING:
    from ..modeling_utils import PreTrainedModel

from ..utils import is_fp_quant_available, is_qutlass_available, is_torch_available, logging
from ..utils.quantization_config import QuantizationConfigMixin


if is_torch_available():
    import torch

logger = logging.get_logger(__name__)


class FPQuantHfQuantizer(HfQuantizer):
    """
    Quantizer for the FP-Quant method. Enables the loading of prequantized models and in-flight quantization of full-precision models.
    """

    requires_calibration = False
    requires_parameters_quantization = True
    is_qat_trainable = False
    required_packages = ["fp_quant"]

    def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
        super().__init__(quantization_config, **kwargs)
        self.quantization_config = quantization_config

    def validate_environment(self, device_map, **kwargs):
        if not torch.cuda.is_available():
            raise NotImplementedError(
                "FPQuant quantization is only supported on GPU. Please use a different quantizer."
            )

        if not is_qutlass_available() and not self.quantization_config.pseudoquantization:
            raise ImportError(
                "Using `fp_quant` with real quantization requires a **Blackwell GPU** and qutlass: `git clone https://github.com/IST-DASLab/qutlass.git && cd qutlass && pip install --no-build-isolation .`. You can use `FPQuantConfig(pseudoquantization=True, ...)` to use Triton-based pseudo-quantization. It doesn't provide any speedups but emulates the quantization behavior of the real quantization."
            )

        if self.quantization_config.pseudoquantization:
            logger.warning(
                "Using pseudo-quantization for FP-Quant. This doesn't provide any speedups but emulates the quantization behavior of the real quantization."
            )

        if not is_fp_quant_available():
            raise ImportError("Using `fp_quant` quantization requires fp_quant: `pip install fp_quant`")

        if device_map is None:
            raise ValueError(
                "You are attempting to load a FPQuant model without setting device_map."
                " Please set device_map comprised of 'cuda' devices."
            )
        elif isinstance(device_map, dict) and ("cpu" in device_map.values() or "disk" in device_map.values()):
            raise ValueError(
                "You are attempting to load a FPQuant model with a device_map that contains a CPU or disk device."
                " This is not supported. Please remove the CPU or disk device from the device_map."
            )

    def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
        if torch_dtype is None:
            logger.info("`torch_dtype` is None. Setting `torch_dtype=torch.bfloat16` for qutlass compatibility.")
            torch_dtype = torch.bfloat16
        elif torch_dtype != torch.bfloat16:
            raise ValueError(
                f"Invalid `torch_dtype` {torch_dtype}. fp_quant quantization only supports `torch_dtype=torch.bfloat16`."
            )

        return torch_dtype

    def create_quantized_param(
        self,
        model: "PreTrainedModel",
        param_value: "torch.Tensor",
        param_name: str,
        target_device: "torch.device",
        state_dict: dict[str, Any],
        unexpected_keys: Optional[list[str]] = None,
    ):
        module, _ = get_module_from_name(model, param_name)

        # The module holds either:
        #  * `weight` when `store_master_weights=True`
        #  * `qweight` and `scales` when `store_master_weights=False` and `pseudoquantization=False`
        #  * `dqweight` when `store_master_weights=False` and `pseudoquantization=True`

        if param_name.endswith(".qweight"):
            # Loading a real quantized checkpoint without master weights
            module.qweight = torch.nn.Parameter(
                param_value.to(target_device),
                requires_grad=False,
            )
            module.weight = None
            module.dqweight = None
            return

        if param_name.endswith(".dqweight"):
            # Loading a pseudo-quantized checkpoint without master weights
            module.dqweight = torch.nn.Parameter(param_value.to(target_device))
            module.weight = None
            module.qweight = None
            module.scales = None
            return

        # Loading master weights or an unquantized checkpoint
        module.weight = torch.nn.Parameter(param_value.to(target_device))
        # Let pre-forward handle the quantization and set None where necessary
        module.pre_forward()

        if unexpected_keys is not None and param_name in unexpected_keys:
            unexpected_keys.remove(param_name)

    def _process_model_before_weight_loading(
        self,
        model: "PreTrainedModel",
        **kwargs,
    ):
        from fp_quant import replace_with_fp_quant_linear

        from ..integrations.fp_quant import adapt_fp_quant_config

        replace_with_fp_quant_linear(
            model,
            fp_quant_linear_config=adapt_fp_quant_config(self.quantization_config),
        )
        model.config.quantization_config = self.quantization_config

    def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
        return model

    def update_missing_keys(self, model, missing_keys: list[str], prefix: str) -> list[str]:
        from fp_quant import FPQuantLinear

        fp_quant_names = {name for name, module in model.named_modules() if isinstance(module, FPQuantLinear)}

        def should_exclude(key: str) -> bool:
            if key.endswith(".weight") or key.endswith(".bias"):
                return False
            full_key = f"{prefix}.{key}"
            return any(name in key or name in full_key for name in fp_quant_names)

        return [key for key in missing_keys if not should_exclude(key)]

    @property
    def is_trainable(self, model: Optional["PreTrainedModel"] = None):
        return False

    def is_serializable(self, safe_serialization=None):
        return True

    def check_quantized_param(
        self,
        model: "PreTrainedModel",
        param_value: "torch.Tensor",
        param_name: str,
        state_dict: dict[str, Any],
        **kwargs,
    ) -> bool:
        from fp_quant import FPQuantLinear

        module, tensor_name = get_module_from_name(model, param_name)
        if isinstance(module, FPQuantLinear) and tensor_name in ["weight", "qweight", "dqweight"]:
            # Only quantize weights of FPQuantLinear modules that are not already quantized
            return True
        else:
            return False
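
Reviewer note: the filtering rule in `update_missing_keys` is easier to follow when pulled out of the class. The sketch below copies the `should_exclude` logic from the diff into a free function that takes the set of quantized module names directly instead of walking `model.named_modules()`. The function name and the parameter names in the sample data are hypothetical and only serve the illustration.

```python
def filter_missing_keys(missing_keys, quantized_module_names, prefix):
    """Drop missing-key reports that belong to quantized modules.

    Keys ending in `.weight` or `.bias` are always kept. Any other key is
    dropped when it contains the name of a quantized module, either as given
    or after prepending the model prefix.
    """

    def should_exclude(key):
        if key.endswith(".weight") or key.endswith(".bias"):
            return False
        full_key = f"{prefix}.{key}"
        return any(name in key or name in full_key for name in quantized_module_names)

    return [key for key in missing_keys if not should_exclude(key)]


quantized = {"model.layers.0.mlp.up_proj"}  # hypothetical module name
missing = [
    "model.layers.0.mlp.up_proj.qweight",  # quantizer-owned buffer: dropped
    "model.layers.0.mlp.up_proj.scales",  # quantizer-owned buffer: dropped
    "model.layers.0.mlp.up_proj.weight",  # ends in .weight: kept
    "model.norm.scale",  # not a quantized module: kept
]
print(filter_missing_keys(missing, quantized, prefix="model"))
# ['model.layers.0.mlp.up_proj.weight', 'model.norm.scale']
```

Note that `.qweight` does not end with `.weight` (the character before `weight` is `q`, not a dot), so quantized buffers are not caught by the first check.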


@@ -90,6 +90,7 @@ from .utils import (
    is_flash_attn_3_available,
    is_flax_available,
    is_flute_available,
    is_fp_quant_available,
    is_fsdp_available,
    is_ftfy_available,
    is_g2p_en_available,
@@ -127,6 +128,7 @@ from .utils import (
    is_pytest_available,
    is_pytorch_quantization_available,
    is_quark_available,
    is_qutlass_available,
    is_rjieba_available,
    is_sacremoses_available,
    is_safetensors_available,
@@ -1467,6 +1469,20 @@ def require_flute_hadamard(test_case):
    )(test_case)


def require_fp_quant(test_case):
    """
    Decorator marking a test that requires fp_quant
    """
    return unittest.skipUnless(is_fp_quant_available(), "test requires fp_quant")(test_case)


def require_qutlass(test_case):
    """
    Decorator marking a test that requires qutlass
    """
    return unittest.skipUnless(is_qutlass_available(), "test requires qutlass")(test_case)


def require_phonemizer(test_case):
    """
    Decorator marking a test that requires phonemizer
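
Reviewer note: both new decorators follow the same `unittest.skipUnless` pattern as the rest of `testing_utils`. For readers unfamiliar with it, the sketch below shows the pattern with only the standard library. `is_package_importable` and `require_package` are hypothetical helpers written for this note; the real decorators call `is_fp_quant_available()` and `is_qutlass_available()` instead.

```python
import importlib.util
import unittest


def is_package_importable(name):
    """Return True when a top-level package can be found on the import path."""
    return importlib.util.find_spec(name) is not None


def require_package(name):
    """Build a decorator that skips a test (or test class) when `name` is missing."""

    def decorator(test_case):
        return unittest.skipUnless(is_package_importable(name), f"test requires {name}")(test_case)

    return decorator


@require_package("json")  # part of the standard library, so this class runs
class RunsEverywhere(unittest.TestCase):
    def test_ok(self):
        self.assertTrue(True)


@require_package("no_such_package_for_this_sketch")  # never importable, so skipped
class SkippedWithoutPackage(unittest.TestCase):
    def test_never_runs(self):
        self.fail("should have been skipped")
```

Stacking `@require_fp_quant` and `@require_qutlass` on one class, as the new test file does, skips the class unless both checks pass.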


@@ -158,6 +158,7 @@ from .import_utils import (
    is_flash_attn_greater_or_equal_2_10,
    is_flax_available,
    is_flute_available,
    is_fp_quant_available,
    is_fsdp_available,
    is_ftfy_available,
    is_g2p_en_available,
@@ -204,6 +205,7 @@ from .import_utils import (
    is_pytest_available,
    is_pytorch_quantization_available,
    is_quark_available,
    is_qutlass_available,
    is_rich_available,
    is_rjieba_available,
    is_rocm_platform,


@@ -172,6 +172,8 @@ _auto_round_available, _auto_round_version = _is_package_available("auto_round",
# `importlib.metadata.version` doesn't work with `awq`
_auto_awq_available = importlib.util.find_spec("awq") is not None
_quark_available = _is_package_available("quark")
_fp_quant_available, _fp_quant_version = _is_package_available("fp_quant", return_version=True)
_qutlass_available = _is_package_available("qutlass")
_is_optimum_quanto_available = False
try:
    importlib.metadata.version("optimum_quanto")
@@ -1314,6 +1316,14 @@ def is_quark_available():
    return _quark_available


def is_fp_quant_available():
    return _fp_quant_available and version.parse(_fp_quant_version) >= version.parse("0.1.6")


def is_qutlass_available():
    return _qutlass_available


def is_compressed_tensors_available():
    return _compressed_tensors_available
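
Reviewer note: `is_fp_quant_available` compares parsed versions rather than raw strings, and the difference matters as soon as a component reaches two digits. The standard-library sketch below illustrates why. `parse_release` and `meets_minimum` are hypothetical helpers that only handle plain dotted integers; the diff itself relies on `version.parse`, which also understands pre-releases and other suffixes that this sketch does not.

```python
def parse_release(version_string):
    """Turn '0.1.10' into (0, 1, 10). Only plain dotted integers are handled."""
    return tuple(int(part) for part in version_string.split("."))


def meets_minimum(installed, minimum="0.1.6"):
    """Compare versions component by component, as numbers."""
    return parse_release(installed) >= parse_release(minimum)


print(meets_minimum("0.1.10"))  # True: 10 >= 6 when compared as integers
print("0.1.10" >= "0.1.6")  # False: as strings, "1" sorts before "6"
```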


@@ -63,6 +63,7 @@ class QuantizationMethod(str, Enum):
    SPQR = "spqr"
    FP8 = "fp8"
    QUARK = "quark"
    FPQUANT = "fp_quant"
    AUTOROUND = "auto-round"
@@ -1549,6 +1550,67 @@ class HiggsConfig(QuantizationConfigMixin):
            raise ValueError("hadamard_size must be divisible by group_size")


@dataclass
class FPQuantConfig(QuantizationConfigMixin):
    """
    FPQuantConfig is a configuration class for quantization using the FPQuant method.

    Args:
        forward_dtype (`str`, *optional*, defaults to `"mxfp4"`):
            The dtype to use for the forward pass.
        forward_method (`str`, *optional*, defaults to `"abs_max"`):
            The scaling to use for the forward pass. Can be `"abs_max"` or `"quest"`. `"abs_max"` is better for PTQ, `"quest"` is better for QAT.
        backward_dtype (`str`, *optional*, defaults to `"bf16"`):
            The dtype to use for the backward pass.
        store_master_weights (`bool`, *optional*, defaults to `False`):
            Whether to store the master weights. Needed for QAT over layer weights.
        hadamard_group_size (`int`, *optional*, defaults to 32):
            The group size for the hadamard transform before quantization. For `"quest"`, it matches the MXFP4 group size (32).
        pseudoquantization (`bool`, *optional*, defaults to `False`):
            Whether to use Triton-based pseudo-quantization. Mandatory for non-Blackwell GPUs. Doesn't provide any speedup. For debugging purposes.
        modules_to_not_convert (`list`, *optional*):
            The list of modules to not quantize, useful for quantizing models that explicitly require
            some modules to be left in their original precision.
    """

    def __init__(
        self,
        forward_dtype: str = "mxfp4",
        forward_method: str = "abs_max",
        backward_dtype: str = "bf16",
        store_master_weights: bool = False,
        hadamard_group_size: int = 32,
        pseudoquantization: bool = False,
        modules_to_not_convert: Optional[list[str]] = None,
        **kwargs,
    ):
        self.forward_dtype = forward_dtype
        self.forward_method = forward_method
        self.backward_dtype = backward_dtype
        self.store_master_weights = store_master_weights
        self.hadamard_group_size = hadamard_group_size
        self.pseudoquantization = pseudoquantization
        self.modules_to_not_convert = modules_to_not_convert

        self.quant_method = QuantizationMethod.FPQUANT
        self.post_init()

    def post_init(self):
        r"""
        Safety checker that arguments are correct - also replaces some NoneType arguments with their default values.
        """
        if self.forward_dtype not in ["mxfp4"]:
            raise ValueError("Only 'mxfp4' is supported for forward_dtype for now.")
        if self.forward_method not in ["abs_max", "quest"]:
            raise ValueError("Only 'abs_max' and 'quest' are supported for forward_method for now.")
        if self.backward_dtype not in ["bf16"]:
            raise ValueError("Only 'bf16' is supported for backward_dtype for now.")
        if self.hadamard_group_size not in [32]:
            raise ValueError("Only a hadamard_group_size of 32 is supported for now.")
        if self.modules_to_not_convert is None:
            self.modules_to_not_convert = ["lm_head"]


@dataclass
class TorchAoConfig(QuantizationConfigMixin):
    quant_method: QuantizationMethod
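
Reviewer note: as written in this diff, `FPQuantConfig.__init__` accepts `**kwargs` and never reads them, so a misspelled keyword would be ignored rather than rejected. The sketch below reproduces only that constructor shape to make the behavior visible. `SketchFPQuantConfig` is a hypothetical, trimmed-down stand-in, not the real class, and it does not inherit from `QuantizationConfigMixin`.

```python
from typing import Optional


class SketchFPQuantConfig:
    """Hypothetical stand-in copying the constructor shape shown in the diff."""

    def __init__(
        self,
        forward_dtype: str = "mxfp4",
        pseudoquantization: bool = False,
        modules_to_not_convert: Optional[list] = None,
        **kwargs,  # accepted but never read, as in the diff
    ):
        self.forward_dtype = forward_dtype
        self.pseudoquantization = pseudoquantization
        self.modules_to_not_convert = modules_to_not_convert
        self.post_init()

    def post_init(self):
        # Same checks as the diff, limited to the fields kept in this sketch.
        if self.forward_dtype not in ["mxfp4"]:
            raise ValueError("Only 'mxfp4' is supported for forward_dtype for now.")
        if self.modules_to_not_convert is None:
            self.modules_to_not_convert = ["lm_head"]


config = SketchFPQuantConfig(pseudoquant=True)  # misspelled keyword
print(config.pseudoquantization)  # False: the typo was silently dropped
print(config.modules_to_not_convert)  # ['lm_head']
```

If the real class behaves the same way, passing the full name `pseudoquantization=True` is the only spelling that takes effect.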


@@ -0,0 +1,227 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import tempfile
import unittest

from transformers import AutoModelForCausalLM, AutoTokenizer, FPQuantConfig
from transformers.testing_utils import (
    backend_empty_cache,
    require_accelerate,
    require_fp_quant,
    require_qutlass,
    require_torch_gpu,
    require_torch_multi_gpu,
    slow,
    torch_device,
)


@require_torch_gpu
class FPQuantConfigTest(unittest.TestCase):
    def test_to_dict(self):
        """
        Simple test that checks if one uses a config and converts it to a dict, the dict is the same as the config object
        """
        quantization_config = FPQuantConfig()
        config_to_dict = quantization_config.to_dict()

        for key in config_to_dict:
            self.assertEqual(getattr(quantization_config, key), config_to_dict[key])

    def test_from_dict(self):
        """
        Simple test that checks if one uses a dict and converts it to a config object, the config object is the same as the dict
        """
        dict = {"modules_to_not_convert": ["embed_tokens", "lm_head"], "quant_method": "fp_quant"}
        quantization_config = FPQuantConfig.from_dict(dict)

        self.assertEqual(dict["modules_to_not_convert"], quantization_config.modules_to_not_convert)
        self.assertEqual(dict["quant_method"], quantization_config.quant_method)


@slow
@require_torch_gpu
@require_fp_quant
@require_qutlass
@require_accelerate
class FPQuantTest(unittest.TestCase):
    model_name = "unsloth/Llama-3.2-1B"

    input_text = "1 2 3 4"
    max_new_tokens = 4

    EXPECTED_OUTPUT = "1 2 3 4 5 6"

    device_map = "cuda"

    # called only once for all test in this class
    @classmethod
    def setUpClass(cls):
        """
        Setup quantized model
        """
        quantization_config = FPQuantConfig(pseudoquantization=False)
        cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name)
        cls.quantized_model = AutoModelForCausalLM.from_pretrained(
            cls.model_name, device_map=cls.device_map, quantization_config=quantization_config
        )

    def tearDown(self):
        gc.collect()
        backend_empty_cache(torch_device)
        gc.collect()

    def test_quantized_model(self):
        """
        Simple test that checks if the quantized model is working properly
        """
        input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)

        output = self.quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
        self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)

    def test_save_pretrained(self):
        """
        Simple test that checks if the quantized model is working properly after being saved and loaded
        """
        with tempfile.TemporaryDirectory() as tmpdirname:
            self.quantized_model.save_pretrained(tmpdirname)

            model = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map=self.device_map)

            input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)

            output = model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
            self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)

    @require_torch_multi_gpu
    def test_quantized_model_multi_gpu(self):
        """
        Simple test that checks if the quantized model is working properly with multiple GPUs
        set CUDA_VISIBLE_DEVICES=0,1 if you have more than 2 GPUs
        """
        input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)
        quantization_config = FPQuantConfig()
        quantized_model = AutoModelForCausalLM.from_pretrained(
            self.model_name, device_map="auto", quantization_config=quantization_config
        )
        self.assertTrue(set(quantized_model.hf_device_map.values()) == {0, 1})

        output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
        self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)

    @require_torch_multi_gpu
    def test_save_pretrained_multi_gpu(self):
        """
        Simple test that checks if the quantized model is working properly after being saved and loaded
        """
        with tempfile.TemporaryDirectory() as tmpdirname:
            self.quantized_model.save_pretrained(tmpdirname)

            model = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map="auto")
            self.assertTrue(set(model.hf_device_map.values()) == {0, 1})

            input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)

            output = model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
            self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)


@slow
@require_torch_gpu
@require_fp_quant
@require_accelerate
class FPQuantPseudoquantTest(unittest.TestCase):
    model_name = "unsloth/Llama-3.2-1B"

    input_text = "1 2 3 4"
    max_new_tokens = 4

    EXPECTED_OUTPUT = "1 2 3 4 5 6"

    device_map = "cuda"

    # called only once for all test in this class
    @classmethod
    def setUpClass(cls):
        """
        Setup quantized model
        """
        quantization_config = FPQuantConfig(pseudoquantization=True)
        cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name)
        cls.quantized_model = AutoModelForCausalLM.from_pretrained(
            cls.model_name, device_map=cls.device_map, quantization_config=quantization_config
        )

    def tearDown(self):
        gc.collect()
        backend_empty_cache(torch_device)
        gc.collect()

    def test_quantized_model(self):
        """
        Simple test that checks if the quantized model is working properly
        """
        input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)

        output = self.quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
        self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)

    def test_save_pretrained(self):
        """
        Simple test that checks if the quantized model is working properly after being saved and loaded
        """
        with tempfile.TemporaryDirectory() as tmpdirname:
            self.quantized_model.save_pretrained(tmpdirname)

            model = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map=self.device_map)

            input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)

            output = model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
            self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)

    @require_torch_multi_gpu
    def test_quantized_model_multi_gpu(self):
        """
        Simple test that checks if the quantized model is working properly with multiple GPUs
        set CUDA_VISIBLE_DEVICES=0,1 if you have more than 2 GPUs
        """
        input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)
        quantization_config = FPQuantConfig(pseudoquantization=True)
        quantized_model = AutoModelForCausalLM.from_pretrained(
            self.model_name, device_map="auto", quantization_config=quantization_config
        )
        self.assertTrue(set(quantized_model.hf_device_map.values()) == {0, 1})

        output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
        self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)

    @require_torch_multi_gpu
    def test_save_pretrained_multi_gpu(self):
        """
        Simple test that checks if the quantized model is working properly after being saved and loaded
        """
        with tempfile.TemporaryDirectory() as tmpdirname:
            self.quantized_model.save_pretrained(tmpdirname)

            model = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map="auto")
            self.assertTrue(set(model.hf_device_map.values()) == {0, 1})

            input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)

            output = model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
            self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)