From 623ab01039930c173a22832540773873ecaa00c2 Mon Sep 17 00:00:00 2001 From: Andrei Panferov Date: Wed, 23 Jul 2025 11:41:10 +0200 Subject: [PATCH] FP-Quant support (#38696) * quartet * quartet qat -> quartet * format * bf16 backward * interfaces * forward_method * quartet -> fp_quant * style * List -> list * list typing * fixed format and annotations * test_fp_quant * docstrings and default dtypes * better docstring and removed noop checks * docs * pseudoquantization support to test on non-blackwell * pseudoquant * Pseudoquant docs * Update docs/source/en/quantization/fp_quant.md Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> * Update docs/source/en/quantization/fp_quant.md * Update docs/source/en/quantization/fp_quant.md * Update src/transformers/utils/quantization_config.py Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com> * Update tests/quantization/fp_quant_integration/test_fp_quant.py Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com> * Update tests/quantization/fp_quant_integration/test_fp_quant.py Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> * small test fixes * dockerfile update * spec link * removed `_process_model_after_weight_loading` * toctree --------- Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com> --- .../Dockerfile | 3 + docs/source/en/_toctree.yml | 2 + docs/source/en/main_classes/quantization.md | 4 + docs/source/en/quantization/fp_quant.md | 66 +++++ docs/source/en/quantization/overview.md | 1 + src/transformers/__init__.py | 2 + src/transformers/integrations/fp_quant.py | 47 ++++ src/transformers/quantizers/auto.py | 4 + .../quantizers/quantizer_fp_quant.py | 183 ++++++++++++++ src/transformers/testing_utils.py | 16 ++ src/transformers/utils/__init__.py | 2 + src/transformers/utils/import_utils.py | 10 + src/transformers/utils/quantization_config.py | 62 +++++ .../fp_quant_integration/__init__.py | 0 .../fp_quant_integration/test_fp_quant.py | 227 ++++++++++++++++++ 15 files changed, 629 insertions(+) create mode 100644 docs/source/en/quantization/fp_quant.md create mode 100644 src/transformers/integrations/fp_quant.py create mode 100644 src/transformers/quantizers/quantizer_fp_quant.py create mode 100644 tests/quantization/fp_quant_integration/__init__.py create mode 100644 tests/quantization/fp_quant_integration/test_fp_quant.py diff --git a/docker/transformers-quantization-latest-gpu/Dockerfile b/docker/transformers-quantization-latest-gpu/Dockerfile index cfc0478016..e930d6d526 100755 --- a/docker/transformers-quantization-latest-gpu/Dockerfile +++ b/docker/transformers-quantization-latest-gpu/Dockerfile @@ -78,6 +78,9 @@ RUN git clone https://github.com/NetEase-FuXi/EETQ.git && cd EETQ/ && git submod # RUN python3 -m pip install --no-cache-dir flute-kernel==0.4.1 # RUN python3 -m pip install --no-cache-dir git+https://github.com/Dao-AILab/fast-hadamard-transform.git +# Add fp-quant for quantization testing +RUN python3 -m pip install --no-cache-dir "fp-quant>=0.1.6" + # Add compressed-tensors for quantization testing RUN python3 -m pip install --no-cache-dir compressed-tensors diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index cdb89d27cc..c7fc602065 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -179,6 +179,8 @@ title: FBGEMM - local: quantization/finegrained_fp8 title: Fine-grained FP8 + - local: quantization/fp_quant + title: FP-Quant - local: gguf title: GGUF - local: quantization/gptq diff --git a/docs/source/en/main_classes/quantization.md b/docs/source/en/main_classes/quantization.md index 83bc5451bc..992f629e5a 100755 --- a/docs/source/en/main_classes/quantization.md +++ b/docs/source/en/main_classes/quantization.md @@ -93,6 +93,10 @@ Learn how to quantize models in the [Quantization](../quantization) guide. [[autodoc]] QuarkConfig +## FPQuantConfig + +[[autodoc]] FPQuantConfig + ## AutoRoundConfig [[autodoc]] AutoRoundConfig diff --git a/docs/source/en/quantization/fp_quant.md b/docs/source/en/quantization/fp_quant.md new file mode 100644 index 0000000000..a89e35da5c --- /dev/null +++ b/docs/source/en/quantization/fp_quant.md @@ -0,0 +1,66 @@ + + +# FP-Quant + +[FP-Quant](https://github.com/IST-DASLab/FP-Quant) is a family of quantization algorithms tailored for the Blackwell generation of Nvidia GPUs. The goal is to allow for efficient post-training quantization (PTQ) and quantization-aware trainin (QAT) of LLMs in the [MXFP4 and NVFP4 data-types](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf). + +Currently, only PTQ with MXFP4 is supported. Models can either be quantized on the fly with `quantization_config=FPQuantConfig()`: + +```python +from transformers import AutoModelForCausalLM, AutoTokenizer, FPQuantConfig +import torch + +model = AutoModelForCausalLM.from_pretrained( + "qwen/Qwen3-8B", + quantization_config=FPQuantConfig(), + device_map="cuda", + torch_dtype=torch.bfloat16, +) +``` + +or pre-processed with GPTQ for better quality (see [FP Format Quantization Harness](https://github.com/IST-DASLab/FP-Quant)). + +A **Blackwell-generation GPU is required** to run the kernels. Runtime support for FP-Quant is implemented through the [QuTLASS](https://github.com/IST-DASLab/qutlass) library and a lightweight PyTorch interface lib [`fp_quant`](https://github.com/IST-DASLab/FP-Quant/tree/master/inference_lib). We recommend installing the former **from source** and the latter with `pip install fp_quant`. + +Users **without a Blackwell-generation GPU** , can use the method with `quantization_config=FPQuantConfig(pseudoquant=True)` without having to install [QuTLASS](https://github.com/IST-DASLab/qutlass). This would provide no speedups but would fully emulate the effect of quantization. + +> [!TIP] +> Find models pre-quantized with FP-Quant in the official ISTA-DASLab [collection](https://huggingface.co/collections/ISTA-DASLab/fp-quant-6877c186103a21d3a02568ee). + +## torch.compile + +FP-Quant is fully compatible with [torch.compile](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html). + +```python +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer, FPQuantConfig + +model = AutoModelForCausalLM.from_pretrained( + "qwen/Qwen3-8B", + quantization_config=FPQuantConfig(), + device_map="cuda", + torch_dtype=torch.bfloat16, +) + +model.forward = torch.compile(model.forward, mode="max-autotune", fullgraph=True) +``` + +## Speedups + +FP-Quant currently performs best for very large batch size processing. + +See [QuTLASS README](https://github.com/IST-DASLab/qutlass/blob/main/README.md) for speedups. \ No newline at end of file diff --git a/docs/source/en/quantization/overview.md b/docs/source/en/quantization/overview.md index 9c36bb5976..f551d0690d 100644 --- a/docs/source/en/quantization/overview.md +++ b/docs/source/en/quantization/overview.md @@ -30,6 +30,7 @@ Use the Space below to help you pick a quantization method depending on your har | [bitsandbytes](./bitsandbytes) | 🟢 | 🟡 | 🟢 | 🟡 | 🔴 | 🟡 | 🟢 | 4/8 | 🟢 | 🟢 | 🟢 | https://github.com/bitsandbytes-foundation/bitsandbytes | | [compressed-tensors](./compressed_tensors) | 🔴 | 🟢 | 🟢 | 🟢 | 🔴 | 🔴 | 🔴 | 1/8 | 🟢 | 🟢 | 🟢 | https://github.com/neuralmagic/compressed-tensors | | [EETQ](./eetq) | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | ? | 8 | 🟢 | 🟢 | 🟢 | https://github.com/NetEase-FuXi/EETQ | +| [FP-Quant](./fp_quant) | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | 🟢 | 4 | 🔴 | 🟢 | 🟢 | https://github.com/IST-DASLab/FP-Quant | | [GGUF / GGML (llama.cpp)](../gguf) | 🟢 | 🟢 | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 | 1/8 | 🔴 | [See Notes](../gguf) | [See Notes](../gguf) | https://github.com/ggerganov/llama.cpp | | [GPTQModel](./gptq) | 🔴 | 🟢 | 🟢 | 🟢 | 🟢 | 🟢 | 🔴 | 2/3/4/8 | 🟢 | 🟢 | 🟢 | https://github.com/ModelCloud/GPTQModel | | [AutoGPTQ](./gptq) | 🔴 | 🔴 | 🟢 | 🟢 | 🔴 | 🔴 | 🔴 | 2/3/4/8 | 🟢 | 🟢 | 🟢 | https://github.com/AutoGPTQ/AutoGPTQ | diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index f6f6fd6f6e..9186277fdd 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -275,6 +275,7 @@ _import_structure = { "HqqConfig", "QuantoConfig", "QuarkConfig", + "FPQuantConfig", "SpQRConfig", "TorchAoConfig", "VptqConfig", @@ -961,6 +962,7 @@ if TYPE_CHECKING: EetqConfig, FbgemmFp8Config, FineGrainedFP8Config, + FPQuantConfig, GPTQConfig, HiggsConfig, HqqConfig, diff --git a/src/transformers/integrations/fp_quant.py b/src/transformers/integrations/fp_quant.py new file mode 100644 index 0000000000..89ebac7004 --- /dev/null +++ b/src/transformers/integrations/fp_quant.py @@ -0,0 +1,47 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"FP-Quant integration file" + +from ..utils import ( + is_fp_quant_available, +) + + +if is_fp_quant_available(): + from fp_quant import FPQuantConfig as FPQuantLinearConfig + from fp_quant import FPQuantDtype + +from transformers.utils.quantization_config import FPQuantConfig + + +def adapt_fp_quant_config(config: FPQuantConfig): + if config.forward_dtype == "mxfp4": + forward_dtype = FPQuantDtype.MXFP4 + else: + raise ValueError(f"Unsupported forward dtype: {config.forward_dtype}") + + if config.backward_dtype == "bf16": + backward_dtype = FPQuantDtype.BF16 + else: + raise ValueError(f"Unsupported backward dtype: {config.backward_dtype}") + + return FPQuantLinearConfig( + forward_dtype=forward_dtype, + forward_method=config.forward_method, + backward_dtype=backward_dtype, + store_master_weights=config.store_master_weights, + hadamard_group_size=config.hadamard_group_size, + pseudoquantization=config.pseudoquantization, + modules_to_not_convert=config.modules_to_not_convert, + ) diff --git a/src/transformers/quantizers/auto.py b/src/transformers/quantizers/auto.py index 7cda893e3c..e4fbaadb5d 100644 --- a/src/transformers/quantizers/auto.py +++ b/src/transformers/quantizers/auto.py @@ -27,6 +27,7 @@ from ..utils.quantization_config import ( EetqConfig, FbgemmFp8Config, FineGrainedFP8Config, + FPQuantConfig, GPTQConfig, HiggsConfig, HqqConfig, @@ -49,6 +50,7 @@ from .quantizer_compressed_tensors import CompressedTensorsHfQuantizer from .quantizer_eetq import EetqHfQuantizer from .quantizer_fbgemm_fp8 import FbgemmFp8HfQuantizer from .quantizer_finegrained_fp8 import FineGrainedFP8HfQuantizer +from .quantizer_fp_quant import FPQuantHfQuantizer from .quantizer_gptq import GptqHfQuantizer from .quantizer_higgs import HiggsHfQuantizer from .quantizer_hqq import HqqHfQuantizer @@ -67,6 +69,7 @@ AUTO_QUANTIZER_MAPPING = { "aqlm": AqlmHfQuantizer, "quanto": QuantoHfQuantizer, "quark": QuarkHfQuantizer, + "fp_quant": FPQuantHfQuantizer, "eetq": EetqHfQuantizer, "higgs": HiggsHfQuantizer, "hqq": HqqHfQuantizer, @@ -89,6 +92,7 @@ AUTO_QUANTIZATION_CONFIG_MAPPING = { "aqlm": AqlmConfig, "quanto": QuantoConfig, "quark": QuarkConfig, + "fp_quant": FPQuantConfig, "hqq": HqqConfig, "compressed-tensors": CompressedTensorsConfig, "fbgemm_fp8": FbgemmFp8Config, diff --git a/src/transformers/quantizers/quantizer_fp_quant.py b/src/transformers/quantizers/quantizer_fp_quant.py new file mode 100644 index 0000000000..a94573a912 --- /dev/null +++ b/src/transformers/quantizers/quantizer_fp_quant.py @@ -0,0 +1,183 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING, Any, Optional + +from .base import HfQuantizer +from .quantizers_utils import get_module_from_name + + +if TYPE_CHECKING: + from ..modeling_utils import PreTrainedModel + +from ..utils import is_fp_quant_available, is_qutlass_available, is_torch_available, logging +from ..utils.quantization_config import QuantizationConfigMixin + + +if is_torch_available(): + import torch + +logger = logging.get_logger(__name__) + + +class FPQuantHfQuantizer(HfQuantizer): + """ + Quantizer for the FP-Quant method. Enables the loading of prequantized models and in-flight quantization of full-precision models. + """ + + requires_calibration = False + requires_parameters_quantization = True + is_qat_trainable = False + required_packages = ["fp_quant"] + + def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs): + super().__init__(quantization_config, **kwargs) + self.quantization_config = quantization_config + + def validate_environment(self, device_map, **kwargs): + if not torch.cuda.is_available(): + raise NotImplementedError( + "FPQuant quantization is only supported on GPU. Please use a different quantizer." + ) + + if not is_qutlass_available() and not self.quantization_config.pseudoquantization: + raise ImportError( + "Using `fp_quant` with real quantization requires a **Blackwell GPU** and qutlass: `git clone https://github.com/IST-DASLab/qutlass.git && cd qutlass && pip install --no-build-isolation .`. You can use `FPQuantConfig(pseudoquantization=True, ...)` to use Triton-based pseudo-quantization. It doesn't provide any speedups but emulates the quantization behavior of the real quantization." + ) + + if self.quantization_config.pseudoquantization: + logger.warning( + "Using pseudo-quantization for FP-Quant. This doesn't provide any speedups but emulates the quantization behavior of the real quantization." + ) + + if not is_fp_quant_available(): + raise ImportError("Using `fp_quant` quantization requires fp_quant: `pip install fp_quant`") + + if device_map is None: + raise ValueError( + "You are attempting to load a FPQuant model without setting device_map." + " Please set device_map comprised of 'cuda' devices." + ) + elif isinstance(device_map, dict) and ("cpu" in device_map.values() or "disk" in device_map.values()): + raise ValueError( + "You are attempting to load a FPQuant model with a device_map that contains a CPU or disk device." + " This is not supported. Please remove the CPU or disk device from the device_map." + ) + + def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype": + if torch_dtype is None: + logger.info("`torch_dtype` is None. Setting `torch_dtype=torch.bfloat16` for qutlass compatibility.") + torch_dtype = torch.bfloat16 + elif torch_dtype != torch.bfloat16: + raise ValueError( + f"Invalid `torch_dtype` {torch_dtype}. fp_quant quantization only supports `torch_dtype=torch.bfloat16`." + ) + + return torch_dtype + + def create_quantized_param( + self, + model: "PreTrainedModel", + param_value: "torch.Tensor", + param_name: str, + target_device: "torch.device", + state_dict: dict[str, Any], + unexpected_keys: Optional[list[str]] = None, + ): + module, _ = get_module_from_name(model, param_name) + + # The module holds either: + # * `weight` when `store_master_weights=True` + # * `qweight` and `scales` when `store_master_weights=False` and `pseudoquantization=False` + # * `dqweight` when `store_master_weights=False` and `pseudoquantization=True` + + if param_name.endswith(".qweight"): + # Loading a real quantized checkpoint without master weights + module.qweight = torch.nn.Parameter( + param_value.to(target_device), + requires_grad=False, + ) + module.weight = None + module.dqweight = None + return + + if param_name.endswith(".dqweight"): + # Loading a pseudo-quantized checkpoint without master weights + module.dqweight = torch.nn.Parameter(param_value.to(target_device)) + module.weight = None + module.qweight = None + module.scales = None + return + + # Loading master weights or an unquantized checkpoint + module.weight = torch.nn.Parameter(param_value.to(target_device)) + # Let pre-forward handle the quantization and set None where necessary + module.pre_forward() + + if unexpected_keys is not None and param_name in unexpected_keys: + unexpected_keys.remove(param_name) + + def _process_model_before_weight_loading( + self, + model: "PreTrainedModel", + **kwargs, + ): + from fp_quant import replace_with_fp_quant_linear + + from ..integrations.fp_quant import adapt_fp_quant_config + + replace_with_fp_quant_linear( + model, + fp_quant_linear_config=adapt_fp_quant_config(self.quantization_config), + ) + model.config.quantization_config = self.quantization_config + + def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs): + return model + + def update_missing_keys(self, model, missing_keys: list[str], prefix: str) -> list[str]: + from fp_quant import FPQuantLinear + + fp_quant_names = {name for name, module in model.named_modules() if isinstance(module, FPQuantLinear)} + + def should_exclude(key: str) -> bool: + if key.endswith(".weight") or key.endswith(".bias"): + return False + full_key = f"{prefix}.{key}" + return any(name in key or name in full_key for name in fp_quant_names) + + return [key for key in missing_keys if not should_exclude(key)] + + @property + def is_trainable(self, model: Optional["PreTrainedModel"] = None): + return False + + def is_serializable(self, safe_serialization=None): + return True + + def check_quantized_param( + self, + model: "PreTrainedModel", + param_value: "torch.Tensor", + param_name: str, + state_dict: dict[str, Any], + **kwargs, + ) -> bool: + from fp_quant import FPQuantLinear + + module, tensor_name = get_module_from_name(model, param_name) + if isinstance(module, FPQuantLinear) and tensor_name in ["weight", "qweight", "dqweight"]: + # Only quantize weights of FPQuantLinear modules that are not already quantized + return True + else: + return False diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index d6b425cca6..fd5f62ec28 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -90,6 +90,7 @@ from .utils import ( is_flash_attn_3_available, is_flax_available, is_flute_available, + is_fp_quant_available, is_fsdp_available, is_ftfy_available, is_g2p_en_available, @@ -127,6 +128,7 @@ from .utils import ( is_pytest_available, is_pytorch_quantization_available, is_quark_available, + is_qutlass_available, is_rjieba_available, is_sacremoses_available, is_safetensors_available, @@ -1467,6 +1469,20 @@ def require_flute_hadamard(test_case): )(test_case) +def require_fp_quant(test_case): + """ + Decorator marking a test that requires fp_quant and qutlass + """ + return unittest.skipUnless(is_fp_quant_available(), "test requires fp_quant")(test_case) + + +def require_qutlass(test_case): + """ + Decorator marking a test that requires qutlass + """ + return unittest.skipUnless(is_qutlass_available(), "test requires qutlass")(test_case) + + def require_phonemizer(test_case): """ Decorator marking a test that requires phonemizer diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py index 1e212b5fa4..0bb3709a42 100644 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -158,6 +158,7 @@ from .import_utils import ( is_flash_attn_greater_or_equal_2_10, is_flax_available, is_flute_available, + is_fp_quant_available, is_fsdp_available, is_ftfy_available, is_g2p_en_available, @@ -204,6 +205,7 @@ from .import_utils import ( is_pytest_available, is_pytorch_quantization_available, is_quark_available, + is_qutlass_available, is_rich_available, is_rjieba_available, is_rocm_platform, diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index f56182d2a5..310b00eb73 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -172,6 +172,8 @@ _auto_round_available, _auto_round_version = _is_package_available("auto_round", # `importlib.metadata.version` doesn't work with `awq` _auto_awq_available = importlib.util.find_spec("awq") is not None _quark_available = _is_package_available("quark") +_fp_quant_available, _fp_quant_version = _is_package_available("fp_quant", return_version=True) +_qutlass_available = _is_package_available("qutlass") _is_optimum_quanto_available = False try: importlib.metadata.version("optimum_quanto") @@ -1314,6 +1316,14 @@ def is_quark_available(): return _quark_available +def is_fp_quant_available(): + return _fp_quant_available and version.parse(_fp_quant_version) >= version.parse("0.1.6") + + +def is_qutlass_available(): + return _qutlass_available + + def is_compressed_tensors_available(): return _compressed_tensors_available diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py index d6f4dd4663..0bc616c6ff 100644 --- a/src/transformers/utils/quantization_config.py +++ b/src/transformers/utils/quantization_config.py @@ -63,6 +63,7 @@ class QuantizationMethod(str, Enum): SPQR = "spqr" FP8 = "fp8" QUARK = "quark" + FPQUANT = "fp_quant" AUTOROUND = "auto-round" @@ -1549,6 +1550,67 @@ class HiggsConfig(QuantizationConfigMixin): raise ValueError("hadamard_size must be divisible by group_size") +@dataclass +class FPQuantConfig(QuantizationConfigMixin): + """ + FPQuantConfig is a configuration class for quantization using the FPQuant method. + + Args: + forward_dtype (`str`, *optional*, defaults to `"mxfp4"`): + The dtype to use for the forward pass. + forward_method (`str`, *optional*, defaults to `"abs_max"`): + The scaling to use for the forward pass. Can be `"abs_max"` or `"quest"`. `"abs_max"` is better for PTQ, `"quest"` is better for QAT. + backward_dtype (`str`, *optional*, defaults to `"bf16"`): + The dtype to use for the backward pass. + store_master_weights (`bool`, *optional*, defaults to `False`): + Whether to store the master weights. Needed for QAT over layer weights. + hadamard_group_size (`int`, *optional*, defaults to 32): + The group size for the hadamard transform before quantization for `"quest"` it matches the MXFP4 group size (32). + pseudoquantization (`bool`, *optional*, defaults to `False`): + Whether to use Triton-based pseudo-quantization. Is mandatory for non-Blackwell GPUs. Doesn't provide any speedup. For debugging purposes. + modules_to_not_convert (`list`, *optional*): + The list of modules to not quantize, useful for quantizing models that explicitly require to have + some modules left in their original precision. + """ + + def __init__( + self, + forward_dtype: str = "mxfp4", + forward_method: str = "abs_max", + backward_dtype: str = "bf16", + store_master_weights: bool = False, + hadamard_group_size: int = 32, + pseudoquantization: bool = False, + modules_to_not_convert: Optional[list[str]] = None, + **kwargs, + ): + self.forward_dtype = forward_dtype + self.forward_method = forward_method + self.backward_dtype = backward_dtype + self.store_master_weights = store_master_weights + self.hadamard_group_size = hadamard_group_size + self.pseudoquantization = pseudoquantization + self.modules_to_not_convert = modules_to_not_convert + + self.quant_method = QuantizationMethod.FPQUANT + self.post_init() + + def post_init(self): + r""" + Safety checker that arguments are correct - also replaces some NoneType arguments with their default values. + """ + if self.forward_dtype not in ["mxfp4"]: + raise ValueError("Only 'mxfp4' is supported for forward_dtype for now.") + if self.forward_method not in ["abs_max", "quest"]: + raise ValueError("Only 'abs_max' and 'quest' are supported for forward_method for now.") + if self.backward_dtype not in ["bf16"]: + raise ValueError("Only 'bf16' is supported for backward_dtype for now.") + if self.hadamard_group_size not in [32]: + raise ValueError("Only a hadamard_group_size of 32 is supported for now.") + if self.modules_to_not_convert is None: + self.modules_to_not_convert = ["lm_head"] + + @dataclass class TorchAoConfig(QuantizationConfigMixin): quant_method: QuantizationMethod diff --git a/tests/quantization/fp_quant_integration/__init__.py b/tests/quantization/fp_quant_integration/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/quantization/fp_quant_integration/test_fp_quant.py b/tests/quantization/fp_quant_integration/test_fp_quant.py new file mode 100644 index 0000000000..2bb60f5a2d --- /dev/null +++ b/tests/quantization/fp_quant_integration/test_fp_quant.py @@ -0,0 +1,227 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import tempfile +import unittest + +from transformers import AutoModelForCausalLM, AutoTokenizer, FPQuantConfig +from transformers.testing_utils import ( + backend_empty_cache, + require_accelerate, + require_fp_quant, + require_qutlass, + require_torch_gpu, + require_torch_multi_gpu, + slow, + torch_device, +) + + +@require_torch_gpu +class FPQuantConfigTest(unittest.TestCase): + def test_to_dict(self): + """ + Simple test that checks if one uses a config and converts it to a dict, the dict is the same as the config object + """ + quantization_config = FPQuantConfig() + config_to_dict = quantization_config.to_dict() + + for key in config_to_dict: + self.assertEqual(getattr(quantization_config, key), config_to_dict[key]) + + def test_from_dict(self): + """ + Simple test that checks if one uses a dict and converts it to a config object, the config object is the same as the dict + """ + dict = {"modules_to_not_convert": ["embed_tokens", "lm_head"], "quant_method": "fp_quant"} + quantization_config = FPQuantConfig.from_dict(dict) + + self.assertEqual(dict["modules_to_not_convert"], quantization_config.modules_to_not_convert) + self.assertEqual(dict["quant_method"], quantization_config.quant_method) + + +@slow +@require_torch_gpu +@require_fp_quant +@require_qutlass +@require_accelerate +class FPQuantTest(unittest.TestCase): + model_name = "unsloth/Llama-3.2-1B" + + input_text = "1 2 3 4" + max_new_tokens = 4 + + EXPECTED_OUTPUT = "1 2 3 4 5 6" + + device_map = "cuda" + + # called only once for all test in this class + @classmethod + def setUpClass(cls): + """ + Setup quantized model + """ + quantization_config = FPQuantConfig(pseudoquantization=False) + cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name) + cls.quantized_model = AutoModelForCausalLM.from_pretrained( + cls.model_name, device_map=cls.device_map, quantization_config=quantization_config + ) + + def tearDown(self): + gc.collect() + backend_empty_cache(torch_device) + gc.collect() + + def test_quantized_model(self): + """ + Simple test that checks if the quantized model is working properly + """ + input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device) + + output = self.quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens) + self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT) + + def test_save_pretrained(self): + """ + Simple test that checks if the quantized model is working properly after being saved and loaded + """ + with tempfile.TemporaryDirectory() as tmpdirname: + self.quantized_model.save_pretrained(tmpdirname) + + model = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map=self.device_map) + + input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device) + + output = model.generate(**input_ids, max_new_tokens=self.max_new_tokens) + self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT) + + @require_torch_multi_gpu + def test_quantized_model_multi_gpu(self): + """ + Simple test that checks if the quantized model is working properly with multiple GPUs + set CUDA_VISIBLE_DEVICES=0,1 if you have more than 2 GPUs + """ + input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device) + quantization_config = FPQuantConfig() + quantized_model = AutoModelForCausalLM.from_pretrained( + self.model_name, device_map="auto", quantization_config=quantization_config + ) + self.assertTrue(set(quantized_model.hf_device_map.values()) == {0, 1}) + + output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens) + self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT) + + @require_torch_multi_gpu + def test_save_pretrained_multi_gpu(self): + """ + Simple test that checks if the quantized model is working properly after being saved and loaded + """ + with tempfile.TemporaryDirectory() as tmpdirname: + self.quantized_model.save_pretrained(tmpdirname) + + model = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map="auto") + self.assertTrue(set(model.hf_device_map.values()) == {0, 1}) + + input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device) + + output = model.generate(**input_ids, max_new_tokens=self.max_new_tokens) + self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT) + + +@slow +@require_torch_gpu +@require_fp_quant +@require_accelerate +class FPQuantPseudoquantTest(unittest.TestCase): + model_name = "unsloth/Llama-3.2-1B" + + input_text = "1 2 3 4" + max_new_tokens = 4 + + EXPECTED_OUTPUT = "1 2 3 4 5 6" + + device_map = "cuda" + + # called only once for all test in this class + @classmethod + def setUpClass(cls): + """ + Setup quantized model + """ + quantization_config = FPQuantConfig(pseudoquantization=True) + cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name) + cls.quantized_model = AutoModelForCausalLM.from_pretrained( + cls.model_name, device_map=cls.device_map, quantization_config=quantization_config + ) + + def tearDown(self): + gc.collect() + backend_empty_cache(torch_device) + gc.collect() + + def test_quantized_model(self): + """ + Simple test that checks if the quantized model is working properly + """ + input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device) + + output = self.quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens) + self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT) + + def test_save_pretrained(self): + """ + Simple test that checks if the quantized model is working properly after being saved and loaded + """ + with tempfile.TemporaryDirectory() as tmpdirname: + self.quantized_model.save_pretrained(tmpdirname) + + model = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map=self.device_map) + + input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device) + + output = model.generate(**input_ids, max_new_tokens=self.max_new_tokens) + self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT) + + @require_torch_multi_gpu + def test_quantized_model_multi_gpu(self): + """ + Simple test that checks if the quantized model is working properly with multiple GPUs + set CUDA_VISIBLE_DEVICES=0,1 if you have more than 2 GPUs + """ + input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device) + quantization_config = FPQuantConfig(pseudoquantization=True) + quantized_model = AutoModelForCausalLM.from_pretrained( + self.model_name, device_map="auto", quantization_config=quantization_config + ) + self.assertTrue(set(quantized_model.hf_device_map.values()) == {0, 1}) + + output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens) + self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT) + + @require_torch_multi_gpu + def test_save_pretrained_multi_gpu(self): + """ + Simple test that checks if the quantized model is working properly after being saved and loaded + """ + with tempfile.TemporaryDirectory() as tmpdirname: + self.quantized_model.save_pretrained(tmpdirname) + + model = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map="auto") + self.assertTrue(set(model.hf_device_map.values()) == {0, 1}) + + input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device) + + output = model.generate(**input_ids, max_new_tokens=self.max_new_tokens) + self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)