Support loading Quark quantized models in Transformers (#36372)

* add quark quantizer

* add quark doc

* clean up doc

* fix tests

* make style

* more style fixes

* cleanup imports

* cleaning

* precise install

* Update docs/source/en/quantization/quark.md

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* Update tests/quantization/quark_integration/test_quark.py

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* Update src/transformers/utils/quantization_config.py

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* remove import guard as suggested

* update copyright headers

* add quark to transformers-quantization-latest-gpu Dockerfile

* make tests pass on transformers main + quark==0.7

* add missing F8_E4M3 and F8_E5M2 keys from str_to_torch_dtype

---------

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Co-authored-by: Bowen Bao <bowenbao@amd.com>
Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>
This commit is contained in:
fxmarty-amd
2025-03-20 15:40:51 +01:00
committed by GitHub
parent ce091b1bda
commit 1a374799ce
15 changed files with 432 additions and 1 deletions

View File

@@ -79,6 +79,9 @@ RUN git clone https://github.com/NetEase-FuXi/EETQ.git && cd EETQ/ && git submod
# Add compressed-tensors for quantization testing
RUN python3 -m pip install --no-cache-dir compressed-tensors
# Add AMD Quark for quantization testing
RUN python3 -m pip install --no-cache-dir amd-quark
# Add transformers in editable mode
RUN python3 -m pip install --no-cache-dir -e ./transformers[dev-torch]

View File

@@ -187,6 +187,8 @@
title: Optimum
- local: quantization/quanto
title: Quanto
- local: quantization/quark
title: Quark
- local: quantization/torchao
title: torchao
- local: quantization/spqr

View File

@@ -88,3 +88,7 @@ Learn how to quantize models in the [Quantization](../quantization) guide.
## FineGrainedFP8Config
[[autodoc]] FineGrainedFP8Config
## QuarkConfig
[[autodoc]] QuarkConfig

View File

@@ -40,6 +40,7 @@ Use the Space below to help you pick a quantization method depending on your har
| [VPTQ](./vptq) | 🔴 | 🔴 | 🟢 | 🟡 | 🔴 | 🔴 | 🟢 | 1/8 | 🔴 | 🟢 | 🟢 | https://github.com/microsoft/VPTQ |
| [FINEGRAINED_FP8](./finegrained_fp8) | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | 🔴 | 8 | 🔴 | 🟢 | 🟢 | |
| [SpQR](./spqr) | 🔴 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | 🟢 | 3 | 🔴 | 🟢 | 🟢 | https://github.com/Vahe1994/SpQR/ |
| [Quark](./quark.md) | 🔴 | 🟢 | 🟢 | 🟢 | 🟢 | 🟢 | ? | 2/4/6/8/9/16 | 🔴 | 🔴 | 🟢 | https://quark.docs.amd.com/latest/ |
## Resources

View File

@@ -0,0 +1,84 @@
<!--Copyright 2025 Advanced Micro Devices, Inc. and The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
# Quark
[Quark](https://quark.docs.amd.com/latest/) is a deep learning quantization toolkit designed to be agnostic to specific data types, algorithms, and hardware. Different pre-processing strategies, algorithms and data-types can be combined in Quark.
The PyTorch support integrated through 🤗 Transformers primarily targets AMD CPUs and GPUs, and is primarily meant to be used for evaluation purposes. For example, it is possible to use [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness) with 🤗 Transformers backend and evaluate a wide range of models quantized through Quark seamlessly.
Users interested in Quark can refer to its [documentation](https://quark.docs.amd.com/latest/) to get started quantizing models and using them in supported open-source libraries!
Although Quark has its own checkpoint / [configuration format](https://huggingface.co/amd/Llama-3.1-8B-Instruct-FP8-KV-Quark-test/blob/main/config.json#L26), the library also supports producing models with a serialization layout compliant with other quantization/runtime implementations ([AutoAWQ](https://huggingface.co/docs/transformers/quantization/awq), [native fp8 in 🤗 Transformers](https://huggingface.co/docs/transformers/quantization/finegrained_fp8)).
To be able to load Quark quantized models in Transformers, the library first needs to be installed:
```bash
pip install amd-quark
```
## Support matrix
Models quantized through Quark support a large range of features, that can be combined together. All quantized models independently of their configuration can seamlessly be reloaded through `PretrainedModel.from_pretrained`.
The table below shows a few features supported by Quark:
| **Feature** | **Supported subset in Quark** | |
|---------------------------------|-----------------------------------------------------------------------------------------------------------|---|
| Data types | int8, int4, int2, bfloat16, float16, fp8_e5m2, fp8_e4m3, fp6_e3m2, fp6_e2m3, fp4, OCP MX, MX6, MX9, bfp16 | |
| Pre-quantization transformation | SmoothQuant, QuaRot, SpinQuant, AWQ | |
| Quantization algorithm | GPTQ | |
| Supported operators | ``nn.Linear``, ``nn.Conv2d``, ``nn.ConvTranspose2d``, ``nn.Embedding``, ``nn.EmbeddingBag`` | |
| Granularity | per-tensor, per-channel, per-block, per-layer, per-layer type | |
| KV cache | fp8 | |
| Activation calibration | MinMax / Percentile / MSE | |
| Quantization strategy | weight-only, static, dynamic, with or without output quantization | |
## Models on Hugging Face Hub
Public models using Quark native serialization can be found at https://huggingface.co/models?other=quark.
Although Quark also supports [models using `quant_method="fp8"`](https://huggingface.co/models?other=fp8) and [models using `quant_method="awq"`](https://huggingface.co/models?other=awq), Transformers loads these models rather through [AutoAWQ](https://huggingface.co/docs/transformers/quantization/awq) or uses the [native fp8 support in 🤗 Transformers](https://huggingface.co/docs/transformers/quantization/finegrained_fp8).
## Using Quark models in Transformers
Here is an example of how one can load a Quark model in Transformers:
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
model_id = "EmbeddedLLM/Llama-3.1-8B-Instruct-w_fp8_per_channel_sym"
model = AutoModelForCausalLM.from_pretrained(model_id)
model = model.to("cuda")
print(model.model.layers[0].self_attn.q_proj)
# QParamsLinear(
# (weight_quantizer): ScaledRealQuantizer()
# (input_quantizer): ScaledRealQuantizer()
# (output_quantizer): ScaledRealQuantizer()
# )
tokenizer = AutoTokenizer.from_pretrained(model_id)
inp = tokenizer("Where is a good place to cycle around Tokyo?", return_tensors="pt")
inp = inp.to("cuda")
res = model.generate(**inp, min_new_tokens=50, max_new_tokens=100)
print(tokenizer.batch_decode(res)[0])
# <|begin_of_text|>Where is a good place to cycle around Tokyo? There are several places in Tokyo that are suitable for cycling, depending on your skill level and interests. Here are a few suggestions:
# 1. Yoyogi Park: This park is a popular spot for cycling and has a wide, flat path that's perfect for beginners. You can also visit the Meiji Shrine, a famous Shinto shrine located in the park.
# 2. Imperial Palace East Garden: This beautiful garden has a large, flat path that's perfect for cycling. You can also visit the
```

2
src/transformers/__init__.py Executable file → Normal file
View File

@@ -1046,6 +1046,7 @@ _import_structure = {
"HiggsConfig",
"HqqConfig",
"QuantoConfig",
"QuarkConfig",
"SpQRConfig",
"TorchAoConfig",
"VptqConfig",
@@ -6287,6 +6288,7 @@ if TYPE_CHECKING:
HiggsConfig,
HqqConfig,
QuantoConfig,
QuarkConfig,
SpQRConfig,
TorchAoConfig,
VptqConfig,

8
src/transformers/modeling_utils.py Executable file → Normal file
View File

@@ -536,6 +536,10 @@ if is_torch_greater_or_equal("2.3.0"):
str_to_torch_dtype["U32"] = torch.uint32
str_to_torch_dtype["U64"] = torch.uint64
if is_torch_greater_or_equal("2.1.0"):
str_to_torch_dtype["F8_E4M3"] = torch.float8_e4m3fn
str_to_torch_dtype["F8_E5M2"] = torch.float8_e5m2
def load_state_dict(
checkpoint_file: Union[str, os.PathLike],
@@ -3675,6 +3679,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
if getattr(self, "quantization_method", None) == QuantizationMethod.HQQ:
raise ValueError("`.to` is not supported for HQQ-quantized models.")
if dtype_present_in_args and getattr(self, "quantization_method", None) == QuantizationMethod.QUARK:
raise ValueError("Casting a Quark quantized model to a new `dtype` is not supported.")
# Checks if the model has been loaded in 4-bit or 8-bit with BNB
if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
if dtype_present_in_args:

5
src/transformers/quantizers/auto.py Executable file → Normal file
View File

@@ -1,4 +1,5 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Modifications Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -31,6 +32,7 @@ from ..utils.quantization_config import (
QuantizationConfigMixin,
QuantizationMethod,
QuantoConfig,
QuarkConfig,
SpQRConfig,
TorchAoConfig,
VptqConfig,
@@ -49,6 +51,7 @@ from .quantizer_gptq import GptqHfQuantizer
from .quantizer_higgs import HiggsHfQuantizer
from .quantizer_hqq import HqqHfQuantizer
from .quantizer_quanto import QuantoHfQuantizer
from .quantizer_quark import QuarkHfQuantizer
from .quantizer_spqr import SpQRHfQuantizer
from .quantizer_torchao import TorchAoHfQuantizer
from .quantizer_vptq import VptqHfQuantizer
@@ -61,6 +64,7 @@ AUTO_QUANTIZER_MAPPING = {
"gptq": GptqHfQuantizer,
"aqlm": AqlmHfQuantizer,
"quanto": QuantoHfQuantizer,
"quark": QuarkHfQuantizer,
"eetq": EetqHfQuantizer,
"higgs": HiggsHfQuantizer,
"hqq": HqqHfQuantizer,
@@ -81,6 +85,7 @@ AUTO_QUANTIZATION_CONFIG_MAPPING = {
"gptq": GPTQConfig,
"aqlm": AqlmConfig,
"quanto": QuantoConfig,
"quark": QuarkConfig,
"hqq": HqqConfig,
"compressed-tensors": CompressedTensorsConfig,
"fbgemm_fp8": FbgemmFp8Config,

View File

@@ -0,0 +1,113 @@
# coding=utf-8
# Copyright 2025 Advanced Micro Devices, Inc. and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Any, Dict
from ..file_utils import is_torch_available
from .base import HfQuantizer
if TYPE_CHECKING:
from ..modeling_utils import PreTrainedModel
if is_torch_available():
import torch
from ..utils import is_accelerate_available, is_quark_available, logging
if is_accelerate_available():
from accelerate.utils import set_module_tensor_to_device
logger = logging.get_logger(__name__)
CHECKPOINT_KEYS = {
"weight_scale": "weight_quantizer.scale",
"bias_scale": "bias_quantizer.scale",
"input_scale": "input_quantizer.scale",
"output_scale": "output_quantizer.scale",
"weight_zero_point": "weight_quantizer.zero_point",
"bias_zero_point": "bias_quantizer.zero_point",
"input_zero_point": "input_quantizer.zero_point",
"output_zero_point": "output_quantizer.zero_point",
}
class QuarkHfQuantizer(HfQuantizer):
"""
Quark quantizer (https://quark.docs.amd.com/latest/).
"""
requires_calibration = True # On-the-fly quantization with quark is not supported for now.
required_packages = ["quark"]
# Checkpoints are expected to be already quantized when loading a quark model. However, as some keys from
# the checkpoint might mismatch the model parameters keys, we use the `create_quantized_param` method
# to load the checkpoints, remapping the keys.
requires_parameters_quantization = True
def __init__(self, quantization_config, **kwargs):
super().__init__(quantization_config, **kwargs)
self.json_export_config = quantization_config.json_export_config
def validate_environment(self, *args, **kwargs):
if not is_quark_available():
raise ImportError(
"Loading a Quark quantized model requires the `quark` library but it was not found in the environment. Please refer to https://quark.docs.amd.com/latest/install.html."
)
def _process_model_before_weight_loading(self, model: "PreTrainedModel", **kwargs):
from quark.torch.export.api import _map_to_quark
_map_to_quark(
model,
self.quantization_config.quant_config,
pack_method=self.json_export_config.pack_method,
custom_mode=self.quantization_config.custom_mode,
)
return model
def check_quantized_param(
self,
model: "PreTrainedModel",
param_value: "torch.Tensor",
param_name: str,
state_dict: Dict[str, Any],
**kwargs,
) -> bool:
return True
def create_quantized_param(
self, model, param, param_name, param_device, state_dict, unexpected_keys
) -> "torch.nn.Parameter":
postfix = param_name.split(".")[-1]
if postfix in CHECKPOINT_KEYS:
param_name = param_name.replace(postfix, CHECKPOINT_KEYS[postfix])
set_module_tensor_to_device(model, param_name, param_device, value=param)
def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
return model
def is_serializable(self, safe_serialization=None):
return False
@property
def is_trainable(self):
return False

View File

@@ -116,6 +116,7 @@ from .utils import (
is_pytesseract_available,
is_pytest_available,
is_pytorch_quantization_available,
is_quark_available,
is_rjieba_available,
is_sacremoses_available,
is_safetensors_available,
@@ -1299,6 +1300,13 @@ def require_fbgemm_gpu(test_case):
return unittest.skipUnless(is_fbgemm_gpu_available(), "test requires fbgemm-gpu")(test_case)
def require_quark(test_case):
"""
Decorator for quark dependency
"""
return unittest.skipUnless(is_quark_available(), "test requires quark")(test_case)
def require_flute_hadamard(test_case):
"""
Decorator marking a test that requires higgs and hadamard

1
src/transformers/utils/__init__.py Executable file → Normal file
View File

@@ -181,6 +181,7 @@ from .import_utils import (
is_pytesseract_available,
is_pytest_available,
is_pytorch_quantization_available,
is_quark_available,
is_rich_available,
is_rjieba_available,
is_sacremoses_available,

16
src/transformers/utils/import_utils.py Executable file → Normal file
View File

@@ -45,6 +45,11 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
package_version = "N/A"
if package_exists:
try:
# TODO: Once python 3.9 support is dropped, `importlib.metadata.packages_distributions()`
# should be used here to map from package name to distribution names
# e.g. PIL -> Pillow, Pillow-SIMD; quark -> amd-quark; onnxruntime -> onnxruntime-gpu.
# `importlib.metadata.packages_distributions()` is not available in Python 3.9.
# Primary method to get the package version
package_version = importlib.metadata.version(pkg_name)
except importlib.metadata.PackageNotFoundError:
@@ -62,6 +67,12 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
except ImportError:
# If the package can't be imported, it's not available
package_exists = False
elif pkg_name == "quark":
# TODO: remove once `importlib.metadata.packages_distributions()` is supported.
try:
package_version = importlib.metadata.version("amd-quark")
except Exception:
package_exists = False
else:
# For packages other than "torch", don't attempt the fallback and set as not available
package_exists = False
@@ -150,6 +161,7 @@ _auto_gptq_available = _is_package_available("auto_gptq")
_gptqmodel_available = _is_package_available("gptqmodel")
# `importlib.metadata.version` doesn't work with `awq`
_auto_awq_available = importlib.util.find_spec("awq") is not None
_quark_available = _is_package_available("quark")
_is_optimum_quanto_available = False
try:
importlib.metadata.version("optimum_quanto")
@@ -1118,6 +1130,10 @@ def is_optimum_quanto_available():
return _is_optimum_quanto_available
def is_quark_available():
return _quark_available
def is_compressed_tensors_available():
return _compressed_tensors_available

41
src/transformers/utils/quantization_config.py Executable file → Normal file
View File

@@ -2,6 +2,7 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
# Modifications Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -31,6 +32,7 @@ from ..utils import (
is_compressed_tensors_available,
is_gptqmodel_available,
is_hqq_available,
is_quark_available,
is_torch_available,
is_torchao_available,
logging,
@@ -60,6 +62,7 @@ class QuantizationMethod(str, Enum):
BITNET = "bitnet"
SPQR = "spqr"
FP8 = "fp8"
QUARK = "quark"
class AWQLinearVersion(str, Enum):
@@ -1772,3 +1775,41 @@ class FineGrainedFP8Config(QuantizationConfigMixin):
raise ValueError("weight_block_size must be a tuple of two integers")
if self.weight_block_size[0] <= 0 or self.weight_block_size[1] <= 0:
raise ValueError("weight_block_size must be a tuple of two positive integers")
class QuarkConfig(QuantizationConfigMixin):
def __init__(
self,
**kwargs,
):
if is_torch_available() and is_quark_available():
from quark import __version__ as quark_version
from quark.torch.export.config.config import JsonExporterConfig
from quark.torch.export.main_export.quant_config_parser import QuantConfigParser
from quark.torch.quantization.config.config import Config
# This might be e.g. `"fp8"` or `"awq"`.
self.custom_mode = kwargs["quant_method"]
self.legacy = "export" not in kwargs
if self.custom_mode in ["awq", "fp8"]:
# Legacy (quark<1.0) or custom export.
self.quant_config = QuantConfigParser.from_custom_config(kwargs, is_bias_quantized=False)
self.json_export_config = JsonExporterConfig()
else:
self.quant_config = Config.from_dict(kwargs)
if "export" in kwargs:
# TODO: Remove this check once configuration version is handled natively by Quark.
if "min_kv_scale" in kwargs["export"] and version.parse(quark_version) < version.parse("0.8"):
min_kv_scale = kwargs["export"].pop("min_kv_scale")
logger.warning(
f"The parameter `min_kv_scale={min_kv_scale}` was found in the model config.json's `quantization_config.export` configuration, but this parameter is supported only for quark>=0.8. Ignoring this configuration parameter. Please update the `amd-quark` package."
)
self.json_export_config = JsonExporterConfig(**kwargs["export"])
else:
# Legacy (quark<1.0) or custom export.
self.json_export_config = JsonExporterConfig()
self.quant_method = QuantizationMethod.QUARK

View File

@@ -0,0 +1,143 @@
# coding=utf-8
# Copyright 2025 Advanced Micro Devices, Inc. and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, GenerationConfig, QuarkConfig
from transformers.testing_utils import (
is_torch_available,
require_accelerate,
require_quark,
require_torch_gpu,
require_torch_multi_gpu,
slow,
)
from transformers.utils.import_utils import is_quark_available
if is_torch_available():
import torch
if is_quark_available():
from quark.torch.export.nn.modules.qparamslinear import QParamsLinear
class QuarkConfigTest(unittest.TestCase):
def test_commmon_args(self):
config = AutoConfig.from_pretrained("amd/Llama-3.1-8B-Instruct-w-int8-a-int8-sym-test")
QuarkConfig(**config.quantization_config)
@slow
@require_quark
@require_torch_gpu
class QuarkTest(unittest.TestCase):
reference_model_name = "meta-llama/Llama-3.1-8B-Instruct"
quantized_model_name = "amd/Llama-3.1-8B-Instruct-w-int8-a-int8-sym-test"
input_text = "Today I am in Paris and"
EXPECTED_OUTPUTS = set()
EXPECTED_OUTPUTS.add("Today I am in Paris and I am not in Paris, France\nToday I am in Paris, Illinois")
EXPECTED_OUTPUTS.add("Today I am in Paris and I am enjoying the city of light. I am not just any ordinary Paris")
EXPECTED_OUTPUTS.add("Today I am in Paris and I am enjoying my day off! The sun is shining, the birds are")
EXPECTED_RELATIVE_DIFFERENCE = 1.66
device_map = None
@classmethod
def setUpClass(cls):
"""
Setup reference & quantized model
"""
cls.model_fp16 = AutoModelForCausalLM.from_pretrained(
cls.reference_model_name, torch_dtype=torch.float16, device_map=cls.device_map
)
cls.mem_fp16 = cls.model_fp16.get_memory_footprint()
cls.tokenizer = AutoTokenizer.from_pretrained(cls.reference_model_name, use_fast=True)
cls.quantized_model = AutoModelForCausalLM.from_pretrained(
cls.quantized_model_name,
torch_dtype=torch.float16,
device_map=cls.device_map,
)
def test_memory_footprint(self):
mem_quantized = self.quantized_model.get_memory_footprint()
self.assertTrue(self.mem_fp16 / mem_quantized > self.EXPECTED_RELATIVE_DIFFERENCE)
def test_device_and_dtype_assignment(self):
r"""
Test whether trying to cast (or assigning a device to) a model after quantization will throw an error.
Checks also if other models are casted correctly.
"""
# This should work
if self.device_map is None:
_ = self.quantized_model.to(0)
with self.assertRaises(ValueError):
# Tries with a `dtype``
self.quantized_model.to(torch.float16)
def test_original_dtype(self):
r"""
A simple test to check if the model succesfully stores the original dtype
"""
self.assertTrue(hasattr(self.quantized_model.config, "_pre_quantization_dtype"))
self.assertFalse(hasattr(self.model_fp16.config, "_pre_quantization_dtype"))
self.assertTrue(self.quantized_model.config._pre_quantization_dtype == torch.float16)
self.assertTrue(isinstance(self.quantized_model.model.layers[0].mlp.gate_proj, QParamsLinear))
def check_inference_correctness(self, model):
r"""
Test the generation quality of the quantized model and see that we are matching the expected output.
Given that we are operating on small numbers + the testing model is relatively small, we might not get
the same output across GPUs. So we'll generate few tokens (5-10) and check their output.
"""
# Check that inference pass works on the model
encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
gen_config = GenerationConfig(
max_new_tokens=15,
min_new_tokens=15,
use_cache=True,
num_beams=1,
do_sample=False,
)
# Check the exactness of the results
output_sequences = model.generate(input_ids=encoded_input["input_ids"].to(0), generation_config=gen_config)
# Get the generation
self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
def test_generate_quality(self):
"""
Simple test to check the quality of the model by comparing the generated tokens with the expected tokens
"""
if self.device_map is None:
self.check_inference_correctness(self.quantized_model.to(0))
else:
self.check_inference_correctness(self.quantized_model)
@require_accelerate
@require_torch_multi_gpu
@require_quark
class QuarkTestDeviceMap(QuarkTest):
device_map = "auto"