Support loading Quark quantized models in Transformers (#36372)

* add quark quantizer
* add quark doc
* clean up doc
* fix tests
* make style
* more style fixes
* cleanup imports
* cleaning
* precise install
* Update docs/source/en/quantization/quark.md
  Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
* Update tests/quantization/quark_integration/test_quark.py
  Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
* Update src/transformers/utils/quantization_config.py
  Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
* remove import guard as suggested
* update copyright headers
* add quark to transformers-quantization-latest-gpu Dockerfile
* make tests pass on transformers main + quark==0.7
* add missing F8_E4M3 and F8_E5M2 keys from str_to_torch_dtype

---------

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Co-authored-by: Bowen Bao <bowenbao@amd.com>
Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>
@@ -79,6 +79,9 @@ RUN git clone https://github.com/NetEase-FuXi/EETQ.git && cd EETQ/ && git submod
# Add compressed-tensors for quantization testing
RUN python3 -m pip install --no-cache-dir compressed-tensors

# Add AMD Quark for quantization testing
RUN python3 -m pip install --no-cache-dir amd-quark

# Add transformers in editable mode
RUN python3 -m pip install --no-cache-dir -e ./transformers[dev-torch]

@@ -187,6 +187,8 @@
  title: Optimum
- local: quantization/quanto
  title: Quanto
- local: quantization/quark
  title: Quark
- local: quantization/torchao
  title: torchao
- local: quantization/spqr
@@ -88,3 +88,7 @@ Learn how to quantize models in the [Quantization](../quantization) guide.
## FineGrainedFP8Config

[[autodoc]] FineGrainedFP8Config

## QuarkConfig

[[autodoc]] QuarkConfig

@@ -40,6 +40,7 @@ Use the Space below to help you pick a quantization method depending on your har
| [VPTQ](./vptq) | 🔴 | 🔴 | 🟢 | 🟡 | 🔴 | 🔴 | 🟢 | 1/8 | 🔴 | 🟢 | 🟢 | https://github.com/microsoft/VPTQ |
| [FINEGRAINED_FP8](./finegrained_fp8) | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | 🔴 | 8 | 🔴 | 🟢 | 🟢 | |
| [SpQR](./spqr) | 🔴 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | 🟢 | 3 | 🔴 | 🟢 | 🟢 | https://github.com/Vahe1994/SpQR/ |
| [Quark](./quark.md) | 🔴 | 🟢 | 🟢 | 🟢 | 🟢 | 🟢 | ? | 2/4/6/8/9/16 | 🔴 | 🔴 | 🟢 | https://quark.docs.amd.com/latest/ |

## Resources

84
docs/source/en/quantization/quark.md
Normal file
@@ -0,0 +1,84 @@
<!--Copyright 2025 Advanced Micro Devices, Inc. and The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# Quark

[Quark](https://quark.docs.amd.com/latest/) is a deep learning quantization toolkit designed to be agnostic to specific data types, algorithms, and hardware. Different pre-processing strategies, algorithms, and data types can be combined in Quark.

The PyTorch support integrated through 🤗 Transformers primarily targets AMD CPUs and GPUs, and is mainly intended for evaluation. For example, [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness) can be used with the 🤗 Transformers backend to evaluate a wide range of models quantized through Quark.

Users interested in Quark can refer to its [documentation](https://quark.docs.amd.com/latest/) to get started quantizing models and using them in supported open-source libraries!

Although Quark has its own checkpoint / [configuration format](https://huggingface.co/amd/Llama-3.1-8B-Instruct-FP8-KV-Quark-test/blob/main/config.json#L26), the library also supports producing models with a serialization layout compliant with other quantization/runtime implementations ([AutoAWQ](https://huggingface.co/docs/transformers/quantization/awq), [native fp8 in 🤗 Transformers](https://huggingface.co/docs/transformers/quantization/finegrained_fp8)).

To load Quark quantized models in Transformers, the library first needs to be installed:

```bash
pip install amd-quark
```
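
Because the package is published as `amd-quark` but imported as `quark`, a quick way to confirm the installation is to look up both names. The snippet below is a minimal sketch using only the standard library; the helper name `quark_install_status` is illustrative and is not part of Quark or Transformers.

```python
import importlib.metadata
import importlib.util


def quark_install_status():
    """Return (importable, version) for the Quark package.

    The import name is `quark`, while the installed distribution is `amd-quark`.
    """
    # True when a top-level module named `quark` can be located.
    importable = importlib.util.find_spec("quark") is not None
    try:
        version = importlib.metadata.version("amd-quark")
    except importlib.metadata.PackageNotFoundError:
        version = None
    return importable, version


importable, version = quark_install_status()
print(f"quark importable: {importable}, amd-quark version: {version}")
```

If the first value is `False` or the version is `None`, loading a Quark quantized model is expected to fail with an `ImportError` pointing at the Quark installation guide.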

## Support matrix

Models quantized through Quark support a wide range of features that can be combined. All quantized models, independently of their configuration, can be reloaded through `PreTrainedModel.from_pretrained`.

The table below shows a few features supported by Quark:

| **Feature**                     | **Supported subset in Quark**                                                                             |
|---------------------------------|-----------------------------------------------------------------------------------------------------------|
| Data types                      | int8, int4, int2, bfloat16, float16, fp8_e5m2, fp8_e4m3, fp6_e3m2, fp6_e2m3, fp4, OCP MX, MX6, MX9, bfp16 |
| Pre-quantization transformation | SmoothQuant, QuaRot, SpinQuant, AWQ                                                                       |
| Quantization algorithm          | GPTQ                                                                                                      |
| Supported operators             | ``nn.Linear``, ``nn.Conv2d``, ``nn.ConvTranspose2d``, ``nn.Embedding``, ``nn.EmbeddingBag``               |
| Granularity                     | per-tensor, per-channel, per-block, per-layer, per-layer type                                             |
| KV cache                        | fp8                                                                                                       |
| Activation calibration          | MinMax / Percentile / MSE                                                                                 |
| Quantization strategy           | weight-only, static, dynamic, with or without output quantization                                         |

## Models on Hugging Face Hub

Public models using Quark native serialization can be found at https://huggingface.co/models?other=quark.

Although Quark also supports [models using `quant_method="fp8"`](https://huggingface.co/models?other=fp8) and [models using `quant_method="awq"`](https://huggingface.co/models?other=awq), Transformers loads these models through [AutoAWQ](https://huggingface.co/docs/transformers/quantization/awq) or the [native fp8 support in 🤗 Transformers](https://huggingface.co/docs/transformers/quantization/finegrained_fp8) instead.
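
The routing described above can be pictured as a lookup on the `quantization_config.quant_method` field of a checkpoint's `config.json`. This is an illustration only: the helper `describe_loader` and the returned strings are hypothetical, and the actual dispatch happens inside Transformers.

```python
# Hypothetical summary of which integration handles each `quant_method` value,
# based on the paragraph above. Not an API of Transformers or Quark.
LOADER_BY_QUANT_METHOD = {
    "quark": "Quark integration (requires amd-quark)",
    "fp8": "native fine-grained FP8 support in Transformers",
    "awq": "AutoAWQ integration",
}


def describe_loader(config: dict) -> str:
    """Describe how a checkpoint would be loaded, given its parsed config.json."""
    quant_config = config.get("quantization_config")
    if quant_config is None:
        return "not quantized"
    method = quant_config.get("quant_method")
    return LOADER_BY_QUANT_METHOD.get(method, f"other quantization method: {method}")


print(describe_loader({"quantization_config": {"quant_method": "quark"}}))
```

Only checkpoints whose `quant_method` is `"quark"` go through the Quark integration and therefore need `amd-quark` installed.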

## Using Quark models in Transformers

Here is an example of how to load a Quark model in Transformers:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "EmbeddedLLM/Llama-3.1-8B-Instruct-w_fp8_per_channel_sym"
model = AutoModelForCausalLM.from_pretrained(model_id)
model = model.to("cuda")

print(model.model.layers[0].self_attn.q_proj)
# QParamsLinear(
#   (weight_quantizer): ScaledRealQuantizer()
#   (input_quantizer): ScaledRealQuantizer()
#   (output_quantizer): ScaledRealQuantizer()
# )

tokenizer = AutoTokenizer.from_pretrained(model_id)
inp = tokenizer("Where is a good place to cycle around Tokyo?", return_tensors="pt")
inp = inp.to("cuda")

res = model.generate(**inp, min_new_tokens=50, max_new_tokens=100)

print(tokenizer.batch_decode(res)[0])
# <|begin_of_text|>Where is a good place to cycle around Tokyo? There are several places in Tokyo that are suitable for cycling, depending on your skill level and interests. Here are a few suggestions:
# 1. Yoyogi Park: This park is a popular spot for cycling and has a wide, flat path that's perfect for beginners. You can also visit the Meiji Shrine, a famous Shinto shrine located in the park.
# 2. Imperial Palace East Garden: This beautiful garden has a large, flat path that's perfect for cycling. You can also visit the
```
2
src/transformers/__init__.py
Executable file → Normal file
@@ -1046,6 +1046,7 @@ _import_structure = {
        "HiggsConfig",
        "HqqConfig",
        "QuantoConfig",
        "QuarkConfig",
        "SpQRConfig",
        "TorchAoConfig",
        "VptqConfig",
@@ -6287,6 +6288,7 @@ if TYPE_CHECKING:
        HiggsConfig,
        HqqConfig,
        QuantoConfig,
        QuarkConfig,
        SpQRConfig,
        TorchAoConfig,
        VptqConfig,
8
src/transformers/modeling_utils.py
Executable file → Normal file
@@ -536,6 +536,10 @@ if is_torch_greater_or_equal("2.3.0"):
    str_to_torch_dtype["U32"] = torch.uint32
    str_to_torch_dtype["U64"] = torch.uint64

if is_torch_greater_or_equal("2.1.0"):
    str_to_torch_dtype["F8_E4M3"] = torch.float8_e4m3fn
    str_to_torch_dtype["F8_E5M2"] = torch.float8_e5m2


def load_state_dict(
    checkpoint_file: Union[str, os.PathLike],
@@ -3675,6 +3679,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix

        if getattr(self, "quantization_method", None) == QuantizationMethod.HQQ:
            raise ValueError("`.to` is not supported for HQQ-quantized models.")

        if dtype_present_in_args and getattr(self, "quantization_method", None) == QuantizationMethod.QUARK:
            raise ValueError("Casting a Quark quantized model to a new `dtype` is not supported.")

        # Checks if the model has been loaded in 4-bit or 8-bit with BNB
        if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
            if dtype_present_in_args:
5
src/transformers/quantizers/auto.py
Executable file → Normal file
@@ -1,4 +1,5 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Modifications Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -31,6 +32,7 @@ from ..utils.quantization_config import (
    QuantizationConfigMixin,
    QuantizationMethod,
    QuantoConfig,
    QuarkConfig,
    SpQRConfig,
    TorchAoConfig,
    VptqConfig,
@@ -49,6 +51,7 @@ from .quantizer_gptq import GptqHfQuantizer
from .quantizer_higgs import HiggsHfQuantizer
from .quantizer_hqq import HqqHfQuantizer
from .quantizer_quanto import QuantoHfQuantizer
from .quantizer_quark import QuarkHfQuantizer
from .quantizer_spqr import SpQRHfQuantizer
from .quantizer_torchao import TorchAoHfQuantizer
from .quantizer_vptq import VptqHfQuantizer
@@ -61,6 +64,7 @@ AUTO_QUANTIZER_MAPPING = {
    "gptq": GptqHfQuantizer,
    "aqlm": AqlmHfQuantizer,
    "quanto": QuantoHfQuantizer,
    "quark": QuarkHfQuantizer,
    "eetq": EetqHfQuantizer,
    "higgs": HiggsHfQuantizer,
    "hqq": HqqHfQuantizer,
@@ -81,6 +85,7 @@ AUTO_QUANTIZATION_CONFIG_MAPPING = {
    "gptq": GPTQConfig,
    "aqlm": AqlmConfig,
    "quanto": QuantoConfig,
    "quark": QuarkConfig,
    "hqq": HqqConfig,
    "compressed-tensors": CompressedTensorsConfig,
    "fbgemm_fp8": FbgemmFp8Config,
113
src/transformers/quantizers/quantizer_quark.py
Normal file
@@ -0,0 +1,113 @@
# coding=utf-8
# Copyright 2025 Advanced Micro Devices, Inc. and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Any, Dict

from ..file_utils import is_torch_available
from .base import HfQuantizer


if TYPE_CHECKING:
    from ..modeling_utils import PreTrainedModel

    if is_torch_available():
        import torch

from ..utils import is_accelerate_available, is_quark_available, logging


if is_accelerate_available():
    from accelerate.utils import set_module_tensor_to_device

logger = logging.get_logger(__name__)


CHECKPOINT_KEYS = {
    "weight_scale": "weight_quantizer.scale",
    "bias_scale": "bias_quantizer.scale",
    "input_scale": "input_quantizer.scale",
    "output_scale": "output_quantizer.scale",
    "weight_zero_point": "weight_quantizer.zero_point",
    "bias_zero_point": "bias_quantizer.zero_point",
    "input_zero_point": "input_quantizer.zero_point",
    "output_zero_point": "output_quantizer.zero_point",
}


class QuarkHfQuantizer(HfQuantizer):
    """
    Quark quantizer (https://quark.docs.amd.com/latest/).
    """

    requires_calibration = True  # On-the-fly quantization with quark is not supported for now.
    required_packages = ["quark"]

    # Checkpoints are expected to be already quantized when loading a quark model. However, as some keys from
    # the checkpoint might mismatch the model parameters keys, we use the `create_quantized_param` method
    # to load the checkpoints, remapping the keys.
    requires_parameters_quantization = True

    def __init__(self, quantization_config, **kwargs):
        super().__init__(quantization_config, **kwargs)

        self.json_export_config = quantization_config.json_export_config

    def validate_environment(self, *args, **kwargs):
        if not is_quark_available():
            raise ImportError(
                "Loading a Quark quantized model requires the `quark` library but it was not found in the environment. Please refer to https://quark.docs.amd.com/latest/install.html."
            )

    def _process_model_before_weight_loading(self, model: "PreTrainedModel", **kwargs):
        from quark.torch.export.api import _map_to_quark

        _map_to_quark(
            model,
            self.quantization_config.quant_config,
            pack_method=self.json_export_config.pack_method,
            custom_mode=self.quantization_config.custom_mode,
        )

        return model

    def check_quantized_param(
        self,
        model: "PreTrainedModel",
        param_value: "torch.Tensor",
        param_name: str,
        state_dict: Dict[str, Any],
        **kwargs,
    ) -> bool:
        return True

    def create_quantized_param(
        self, model, param, param_name, param_device, state_dict, unexpected_keys
    ) -> "torch.nn.Parameter":
        postfix = param_name.split(".")[-1]

        if postfix in CHECKPOINT_KEYS:
            param_name = param_name.replace(postfix, CHECKPOINT_KEYS[postfix])

        set_module_tensor_to_device(model, param_name, param_device, value=param)

    def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
        return model

    def is_serializable(self, safe_serialization=None):
        return False

    @property
    def is_trainable(self):
        return False
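The key remapping done in `create_quantized_param` can be tried in isolation. The sketch below copies `CHECKPOINT_KEYS` from the file above so that it is self-contained; the helper name `remap_checkpoint_key` and the example parameter names are illustrative assumptions, not part of the commit.

```python
# Same mapping as CHECKPOINT_KEYS in quantizer_quark.py, reproduced for a standalone example.
CHECKPOINT_KEYS = {
    "weight_scale": "weight_quantizer.scale",
    "bias_scale": "bias_quantizer.scale",
    "input_scale": "input_quantizer.scale",
    "output_scale": "output_quantizer.scale",
    "weight_zero_point": "weight_quantizer.zero_point",
    "bias_zero_point": "bias_quantizer.zero_point",
    "input_zero_point": "input_quantizer.zero_point",
    "output_zero_point": "output_quantizer.zero_point",
}


def remap_checkpoint_key(param_name: str) -> str:
    """Map a checkpoint key to the parameter name used by the Quark modules."""
    postfix = param_name.split(".")[-1]
    if postfix in CHECKPOINT_KEYS:
        # Mirrors the commit: `str.replace` substitutes every occurrence of `postfix`.
        param_name = param_name.replace(postfix, CHECKPOINT_KEYS[postfix])
    return param_name


print(remap_checkpoint_key("model.layers.0.self_attn.q_proj.weight_scale"))
# model.layers.0.self_attn.q_proj.weight_quantizer.scale
print(remap_checkpoint_key("model.layers.0.self_attn.q_proj.weight"))
# model.layers.0.self_attn.q_proj.weight
```

Keys whose last component is not in the mapping, such as plain `weight`, pass through unchanged.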
@@ -116,6 +116,7 @@ from .utils import (
    is_pytesseract_available,
    is_pytest_available,
    is_pytorch_quantization_available,
    is_quark_available,
    is_rjieba_available,
    is_sacremoses_available,
    is_safetensors_available,
@@ -1299,6 +1300,13 @@ def require_fbgemm_gpu(test_case):
    return unittest.skipUnless(is_fbgemm_gpu_available(), "test requires fbgemm-gpu")(test_case)


def require_quark(test_case):
    """
    Decorator for quark dependency
    """
    return unittest.skipUnless(is_quark_available(), "test requires quark")(test_case)


def require_flute_hadamard(test_case):
    """
    Decorator marking a test that requires higgs and hadamard
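The `require_quark` decorator follows the usual `unittest.skipUnless` pattern. A standalone sketch of that pattern is shown below; `require_module` and `ExampleTest` are hypothetical names, and `importlib.util.find_spec` stands in for the `is_quark_available()` helper used by the commit.

```python
import importlib.util
import unittest


def require_module(name: str):
    """Return a decorator that skips a test unless the top-level module `name` is found."""
    available = importlib.util.find_spec(name) is not None
    return unittest.skipUnless(available, f"test requires {name}")


class ExampleTest(unittest.TestCase):
    @require_module("quark")
    def test_needs_quark(self):
        # Only reached when `quark` is installed; otherwise the test is reported as skipped.
        import quark

        self.assertIsNotNone(quark)
```

Running this class with `python -m unittest` reports the test as skipped on machines without `amd-quark` instead of failing on the import.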
1
src/transformers/utils/__init__.py
Executable file → Normal file
@@ -181,6 +181,7 @@ from .import_utils import (
    is_pytesseract_available,
    is_pytest_available,
    is_pytorch_quantization_available,
    is_quark_available,
    is_rich_available,
    is_rjieba_available,
    is_sacremoses_available,
16
src/transformers/utils/import_utils.py
Executable file → Normal file
@@ -45,6 +45,11 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
    package_version = "N/A"
    if package_exists:
        try:
            # TODO: Once python 3.9 support is dropped, `importlib.metadata.packages_distributions()`
            # should be used here to map from package name to distribution names
            # e.g. PIL -> Pillow, Pillow-SIMD; quark -> amd-quark; onnxruntime -> onnxruntime-gpu.
            # `importlib.metadata.packages_distributions()` is not available in Python 3.9.

            # Primary method to get the package version
            package_version = importlib.metadata.version(pkg_name)
        except importlib.metadata.PackageNotFoundError:
@@ -62,6 +67,12 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
                except ImportError:
                    # If the package can't be imported, it's not available
                    package_exists = False
            elif pkg_name == "quark":
                # TODO: remove once `importlib.metadata.packages_distributions()` is supported.
                try:
                    package_version = importlib.metadata.version("amd-quark")
                except Exception:
                    package_exists = False
            else:
                # For packages other than "torch", don't attempt the fallback and set as not available
                package_exists = False
@@ -150,6 +161,7 @@ _auto_gptq_available = _is_package_available("auto_gptq")
_gptqmodel_available = _is_package_available("gptqmodel")
# `importlib.metadata.version` doesn't work with `awq`
_auto_awq_available = importlib.util.find_spec("awq") is not None
_quark_available = _is_package_available("quark")
_is_optimum_quanto_available = False
try:
    importlib.metadata.version("optimum_quanto")
@@ -1118,6 +1130,10 @@ def is_optimum_quanto_available():
    return _is_optimum_quanto_available


def is_quark_available():
    return _quark_available


def is_compressed_tensors_available():
    return _compressed_tensors_available

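The TODO in the first hunk refers to `importlib.metadata.packages_distributions()`, which maps import names to the distributions that provide them and was added in Python 3.10. A minimal sketch of how such a lookup could work is given below; the helper `distributions_for` is hypothetical and is not what the commit implements (the commit hard-codes the `amd-quark` name instead).

```python
import importlib.metadata
import sys


def distributions_for(import_name: str) -> list:
    """Return the distribution names that provide the top-level module `import_name`.

    Returns an empty list on Python < 3.10 or when no installed distribution
    provides the module.
    """
    if sys.version_info < (3, 10):
        return []
    mapping = importlib.metadata.packages_distributions()
    return list(mapping.get(import_name, []))


# With Quark installed, this is expected to list the distribution that ships `quark`.
print(distributions_for("quark"))
```

Each returned name could then be passed to `importlib.metadata.version` to obtain the installed version without special-casing individual packages.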
41
src/transformers/utils/quantization_config.py
Executable file → Normal file
@@ -2,6 +2,7 @@
# coding=utf-8

# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
# Modifications Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -31,6 +32,7 @@ from ..utils import (
    is_compressed_tensors_available,
    is_gptqmodel_available,
    is_hqq_available,
    is_quark_available,
    is_torch_available,
    is_torchao_available,
    logging,
@@ -60,6 +62,7 @@ class QuantizationMethod(str, Enum):
    BITNET = "bitnet"
    SPQR = "spqr"
    FP8 = "fp8"
    QUARK = "quark"


class AWQLinearVersion(str, Enum):
@@ -1772,3 +1775,41 @@ class FineGrainedFP8Config(QuantizationConfigMixin):
            raise ValueError("weight_block_size must be a tuple of two integers")
        if self.weight_block_size[0] <= 0 or self.weight_block_size[1] <= 0:
            raise ValueError("weight_block_size must be a tuple of two positive integers")


class QuarkConfig(QuantizationConfigMixin):
    def __init__(
        self,
        **kwargs,
    ):
        if is_torch_available() and is_quark_available():
            from quark import __version__ as quark_version
            from quark.torch.export.config.config import JsonExporterConfig
            from quark.torch.export.main_export.quant_config_parser import QuantConfigParser
            from quark.torch.quantization.config.config import Config

        # This might be e.g. `"fp8"` or `"awq"`.
        self.custom_mode = kwargs["quant_method"]
        self.legacy = "export" not in kwargs

        if self.custom_mode in ["awq", "fp8"]:
            # Legacy (quark<1.0) or custom export.
            self.quant_config = QuantConfigParser.from_custom_config(kwargs, is_bias_quantized=False)
            self.json_export_config = JsonExporterConfig()
        else:
            self.quant_config = Config.from_dict(kwargs)

            if "export" in kwargs:
                # TODO: Remove this check once configuration version is handled natively by Quark.
                if "min_kv_scale" in kwargs["export"] and version.parse(quark_version) < version.parse("0.8"):
                    min_kv_scale = kwargs["export"].pop("min_kv_scale")
                    logger.warning(
                        f"The parameter `min_kv_scale={min_kv_scale}` was found in the model config.json's `quantization_config.export` configuration, but this parameter is supported only for quark>=0.8. Ignoring this configuration parameter. Please update the `amd-quark` package."
                    )

                self.json_export_config = JsonExporterConfig(**kwargs["export"])
            else:
                # Legacy (quark<1.0) or custom export.
                self.json_export_config = JsonExporterConfig()

        self.quant_method = QuantizationMethod.QUARK

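The branching in `QuarkConfig.__init__` can be summarized without importing Quark. The sketch below is a hypothetical helper, `plan_quark_config`, that only reports which branch would be taken; the strings name the Quark calls seen in the hunk above, and a plain `(major, minor)` tuple comparison replaces the `packaging.version` comparison used by the commit.

```python
def plan_quark_config(kwargs: dict, quark_version: tuple) -> dict:
    """Summarize which branch of QuarkConfig.__init__ a given config would take.

    `quark_version` is a (major, minor) tuple; this is a simplification of the
    `version.parse(...)` comparison in the commit.
    """
    custom_mode = kwargs["quant_method"]
    plan = {"custom_mode": custom_mode, "legacy": "export" not in kwargs, "ignored": []}

    if custom_mode in ("awq", "fp8"):
        # Legacy (quark<1.0) or custom export: default exporter configuration.
        plan["parser"] = "QuantConfigParser.from_custom_config"
        plan["export_kwargs"] = {}
    else:
        plan["parser"] = "Config.from_dict"
        # Work on a copy; the commit pops the key from kwargs["export"] in place.
        export = dict(kwargs.get("export", {}))
        if "min_kv_scale" in export and quark_version < (0, 8):
            export.pop("min_kv_scale")
            plan["ignored"].append("min_kv_scale")
        plan["export_kwargs"] = export
    return plan


print(plan_quark_config({"quant_method": "quark", "export": {"min_kv_scale": 0.0}}, (0, 7)))
```

With a Quark version below 0.8, `min_kv_scale` ends up in `ignored`, which corresponds to the warning logged by the commit.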
0
tests/quantization/quark_integration/__init__.py
Normal file
143
tests/quantization/quark_integration/test_quark.py
Normal file
@@ -0,0 +1,143 @@
# coding=utf-8
# Copyright 2025 Advanced Micro Devices, Inc. and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, GenerationConfig, QuarkConfig
from transformers.testing_utils import (
    is_torch_available,
    require_accelerate,
    require_quark,
    require_torch_gpu,
    require_torch_multi_gpu,
    slow,
)
from transformers.utils.import_utils import is_quark_available


if is_torch_available():
    import torch

if is_quark_available():
    from quark.torch.export.nn.modules.qparamslinear import QParamsLinear


class QuarkConfigTest(unittest.TestCase):
    def test_commmon_args(self):
        config = AutoConfig.from_pretrained("amd/Llama-3.1-8B-Instruct-w-int8-a-int8-sym-test")
        QuarkConfig(**config.quantization_config)


@slow
@require_quark
@require_torch_gpu
class QuarkTest(unittest.TestCase):
    reference_model_name = "meta-llama/Llama-3.1-8B-Instruct"
    quantized_model_name = "amd/Llama-3.1-8B-Instruct-w-int8-a-int8-sym-test"

    input_text = "Today I am in Paris and"

    EXPECTED_OUTPUTS = set()
    EXPECTED_OUTPUTS.add("Today I am in Paris and I am not in Paris, France\nToday I am in Paris, Illinois")
    EXPECTED_OUTPUTS.add("Today I am in Paris and I am enjoying the city of light. I am not just any ordinary Paris")
    EXPECTED_OUTPUTS.add("Today I am in Paris and I am enjoying my day off! The sun is shining, the birds are")

    EXPECTED_RELATIVE_DIFFERENCE = 1.66
    device_map = None

    @classmethod
    def setUpClass(cls):
        """
        Set up the reference and quantized models
        """
        cls.model_fp16 = AutoModelForCausalLM.from_pretrained(
            cls.reference_model_name, torch_dtype=torch.float16, device_map=cls.device_map
        )
        cls.mem_fp16 = cls.model_fp16.get_memory_footprint()

        cls.tokenizer = AutoTokenizer.from_pretrained(cls.reference_model_name, use_fast=True)

        cls.quantized_model = AutoModelForCausalLM.from_pretrained(
            cls.quantized_model_name,
            torch_dtype=torch.float16,
            device_map=cls.device_map,
        )

    def test_memory_footprint(self):
        mem_quantized = self.quantized_model.get_memory_footprint()

        self.assertTrue(self.mem_fp16 / mem_quantized > self.EXPECTED_RELATIVE_DIFFERENCE)

    def test_device_and_dtype_assignment(self):
        r"""
        Test that casting a quantized model to a new dtype raises an error, while moving it to a device
        (when no `device_map` is used) still works.
        """
        # This should work
        if self.device_map is None:
            _ = self.quantized_model.to(0)

        with self.assertRaises(ValueError):
            # Try casting to a `dtype`
            self.quantized_model.to(torch.float16)

    def test_original_dtype(self):
        r"""
        A simple test to check that the model successfully stores the original dtype
        """
        self.assertTrue(hasattr(self.quantized_model.config, "_pre_quantization_dtype"))
        self.assertFalse(hasattr(self.model_fp16.config, "_pre_quantization_dtype"))
        self.assertTrue(self.quantized_model.config._pre_quantization_dtype == torch.float16)

        self.assertTrue(isinstance(self.quantized_model.model.layers[0].mlp.gate_proj, QParamsLinear))

    def check_inference_correctness(self, model):
        r"""
        Test the generation quality of the quantized model and check that it matches one of the expected outputs.
        Because quantized arithmetic operates on small numbers, the output may differ across GPUs, so only a
        few tokens are generated and compared against a set of accepted outputs.
        """
        # Check that inference pass works on the model
        encoded_input = self.tokenizer(self.input_text, return_tensors="pt")

        gen_config = GenerationConfig(
            max_new_tokens=15,
            min_new_tokens=15,
            use_cache=True,
            num_beams=1,
            do_sample=False,
        )

        # Generate greedily
        output_sequences = model.generate(input_ids=encoded_input["input_ids"].to(0), generation_config=gen_config)

        # Check the decoded generation against the expected outputs
        self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)

    def test_generate_quality(self):
        """
        Simple test to check the quality of the model by comparing the generated tokens with the expected tokens
        """
        if self.device_map is None:
            self.check_inference_correctness(self.quantized_model.to(0))
        else:
            self.check_inference_correctness(self.quantized_model)


@require_accelerate
@require_torch_multi_gpu
@require_quark
class QuarkTestDeviceMap(QuarkTest):
    device_map = "auto"