Add TorchAOHfQuantizer (#32306)

* Add TorchAOHfQuantizer

Summary:
Enable loading torchao quantized model in huggingface.

Test Plan:
local test

Reviewers:

Subscribers:

Tasks:

Tags:

* Fix a few issues

* style

* Added tests and addressed some comments about dtype conversion

* fix torch_dtype warning message

* fix tests

* style

* TorchAOConfig -> TorchAoConfig

* enable offload + fix memory with multi-gpu

* update torchao version requirement to 0.4.0

* better comments

* add torch.compile to torchao README, add perf number link

---------

Co-authored-by: Marc Sun <marc@huggingface.co>
This commit is contained in:
Jerry Zhang
2024-08-14 07:14:24 -07:00
committed by GitHub
parent 9485289f37
commit 78d78cdf8a
13 changed files with 539 additions and 3 deletions

View File

@@ -163,6 +163,8 @@
title: FBGEMM_FP8 title: FBGEMM_FP8
- local: quantization/optimum - local: quantization/optimum
title: Optimum title: Optimum
- local: quantization/torchao
title: TorchAO
- local: quantization/contribute - local: quantization/contribute
title: Contribute new quantization method title: Contribute new quantization method
title: Quantization Methods title: Quantization Methods

View File

@@ -61,3 +61,7 @@ Learn how to quantize models in the [Quantization](../quantization) guide.
[[autodoc]] FbgemmFp8Config [[autodoc]] FbgemmFp8Config
## TorchAoConfig
[[autodoc]] TorchAoConfig

View File

@@ -56,4 +56,4 @@ Use the table below to help you decide which quantization method to use.
| [HQQ](./hqq) | 🟢 | 🟢 | 🟢 | 🔴 | 🔴 | 🟢 | 1 - 8 | 🟢 | 🔴 | 🟢 | https://github.com/mobiusml/hqq/ | | [HQQ](./hqq) | 🟢 | 🟢 | 🟢 | 🔴 | 🔴 | 🟢 | 1 - 8 | 🟢 | 🔴 | 🟢 | https://github.com/mobiusml/hqq/ |
| [Quanto](./quanto) | 🟢 | 🟢 | 🟢 | 🔴 | 🟢 | 🟢 | 2 / 4 / 8 | 🔴 | 🔴 | 🟢 | https://github.com/huggingface/quanto | | [Quanto](./quanto) | 🟢 | 🟢 | 🟢 | 🔴 | 🟢 | 🟢 | 2 / 4 / 8 | 🔴 | 🔴 | 🟢 | https://github.com/huggingface/quanto |
| [FBGEMM_FP8](./fbgemm_fp8.md) | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | 8 | 🔴 | 🟢 | 🟢 | https://github.com/pytorch/FBGEMM | | [FBGEMM_FP8](./fbgemm_fp8.md) | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | 8 | 🔴 | 🟢 | 🟢 | https://github.com/pytorch/FBGEMM |
| [torchao](./torchao.md) | 🟢 | | 🟢 | 🔴 | partial support (int4 weight only) | | 4 / 8 | | 🟢🔴 | 🟢 | https://github.com/pytorch/ao |

View File

@@ -0,0 +1,45 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
# TorchAO
[TorchAO](https://github.com/pytorch/ao) is an architecture optimization library for PyTorch. It provides high-performance dtypes, optimization techniques, and kernels for inference and training, and it composes with native PyTorch features such as `torch.compile` and FSDP. Some benchmark numbers can be found [here](https://github.com/pytorch/ao/tree/main?tab=readme-ov-file#without-intrusive-code-changes).
Before you begin, make sure the following libraries are installed with their latest version:
```bash
pip install --upgrade torch torchao
```
```py
import torch
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
model_name = "meta-llama/Meta-Llama-3-8B"
# We support int4_weight_only, int8_weight_only and int8_dynamic_activation_int8_weight
# More examples and documentation for the arguments can be found in https://github.com/pytorch/ao/tree/main/torchao/quantization#other-available-quantization-techniques
quantization_config = TorchAoConfig("int4_weight_only", group_size=128)
quantized_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", quantization_config=quantization_config)
tokenizer = AutoTokenizer.from_pretrained(model_name)
input_text = "What are we having for dinner?"
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
# compile the quantized model to get a speedup
import torchao
torchao.quantization.utils.recommended_inductor_config_setter()
quantized_model = torch.compile(quantized_model, mode="max-autotune")
output = quantized_model.generate(**input_ids, max_new_tokens=10)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```
torchao quantization is implemented with tensor subclasses. It currently does not work with Hugging Face serialization, with either the safetensors option or the [non-safetensors option](https://github.com/huggingface/transformers/issues/32364). We'll update this page with instructions once it is working.

View File

@@ -949,6 +949,7 @@ _import_structure = {
"GPTQConfig", "GPTQConfig",
"HqqConfig", "HqqConfig",
"QuantoConfig", "QuantoConfig",
"TorchAoConfig",
], ],
} }
@@ -5728,6 +5729,7 @@ if TYPE_CHECKING:
GPTQConfig, GPTQConfig,
HqqConfig, HqqConfig,
QuantoConfig, QuantoConfig,
TorchAoConfig,
) )
try: try:

View File

@@ -26,6 +26,7 @@ from ..utils.quantization_config import (
QuantizationConfigMixin, QuantizationConfigMixin,
QuantizationMethod, QuantizationMethod,
QuantoConfig, QuantoConfig,
TorchAoConfig,
) )
from .quantizer_aqlm import AqlmHfQuantizer from .quantizer_aqlm import AqlmHfQuantizer
from .quantizer_awq import AwqQuantizer from .quantizer_awq import AwqQuantizer
@@ -36,6 +37,7 @@ from .quantizer_fbgemm_fp8 import FbgemmFp8HfQuantizer
from .quantizer_gptq import GptqHfQuantizer from .quantizer_gptq import GptqHfQuantizer
from .quantizer_hqq import HqqHfQuantizer from .quantizer_hqq import HqqHfQuantizer
from .quantizer_quanto import QuantoHfQuantizer from .quantizer_quanto import QuantoHfQuantizer
from .quantizer_torchao import TorchAoHfQuantizer
AUTO_QUANTIZER_MAPPING = { AUTO_QUANTIZER_MAPPING = {
@@ -48,6 +50,7 @@ AUTO_QUANTIZER_MAPPING = {
"eetq": EetqHfQuantizer, "eetq": EetqHfQuantizer,
"hqq": HqqHfQuantizer, "hqq": HqqHfQuantizer,
"fbgemm_fp8": FbgemmFp8HfQuantizer, "fbgemm_fp8": FbgemmFp8HfQuantizer,
"torchao": TorchAoHfQuantizer,
} }
AUTO_QUANTIZATION_CONFIG_MAPPING = { AUTO_QUANTIZATION_CONFIG_MAPPING = {
@@ -60,6 +63,7 @@ AUTO_QUANTIZATION_CONFIG_MAPPING = {
"quanto": QuantoConfig, "quanto": QuantoConfig,
"hqq": HqqConfig, "hqq": HqqConfig,
"fbgemm_fp8": FbgemmFp8Config, "fbgemm_fp8": FbgemmFp8Config,
"torchao": TorchAoConfig,
} }

View File

@@ -0,0 +1,172 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
from typing import TYPE_CHECKING, Union
from packaging import version
from .base import HfQuantizer
from .quantizers_utils import get_module_from_name
if TYPE_CHECKING:
from ..modeling_utils import PreTrainedModel
from typing import Any, Dict, List
from ..utils import is_torch_available, is_torchao_available, logging
if is_torch_available():
import torch
if is_torchao_available():
from torchao.quantization import quantize_
logger = logging.get_logger(__name__)
# Finds the parent of a node module named "name"
def find_parent(model, name):
    """Return the module that directly contains the entry addressed by ``name``.

    ``name`` is a dot-separated path. Every component except the last one is
    looked up, in order, in the ``_modules`` mapping of the current node,
    starting from ``model``. A name without any dot yields ``model`` itself.
    """
    # `str.split` always returns at least one item, so the unpacking cannot fail.
    *ancestor_names, _leaf = name.split(".")
    node = model
    for child_name in ancestor_names:
        node = node._modules[child_name]
    return node
class TorchAoHfQuantizer(HfQuantizer):
    """
    Quantizer for torchao: https://github.com/pytorch/ao/

    The `weight` of every `torch.nn.Linear` module that is not excluded is quantized by calling torchao's
    `quantize_` on the module from `create_quantized_param`.
    """

    requires_parameters_quantization = True
    requires_calibration = False
    required_packages = ["torchao"]

    def __init__(self, quantization_config, **kwargs):
        super().__init__(quantization_config, **kwargs)

    def validate_environment(self, *args, **kwargs):
        """
        Check that torchao is installed and record whether cpu/disk offload is requested.

        `self.offload` is set to `True` only when `device_map` (read from `kwargs`) is a dict that places something
        on `"cpu"` or `"disk"` and the model is not pre-quantized.

        Raises:
            ImportError: if torchao is not available.
            ValueError: if the model is pre-quantized and `device_map` contains a `"cpu"` or `"disk"` entry.
        """
        if not is_torchao_available():
            raise ImportError("Loading an torchao quantized model requires torchao library (`pip install torchao`)")

        self.offload = False

        device_map = kwargs.get("device_map", None)
        if isinstance(device_map, dict):
            if "cpu" in device_map.values() or "disk" in device_map.values():
                if self.pre_quantized:
                    raise ValueError(
                        "You are attempting to perform cpu/disk offload with a pre-quantized torchao model "
                        "This is not supported yet . Please remove the CPU or disk device from the device_map."
                    )
                else:
                    self.offload = True

    def update_torch_dtype(self, torch_dtype):
        """
        Return the dtype to load the model with.

        Only `int4_weight_only` is special-cased: a missing `torch_dtype` becomes `torch.bfloat16` (with a warning),
        and any other explicit dtype triggers a warning but is returned unchanged.
        """
        if self.quantization_config.quant_type == "int4_weight_only":
            if torch_dtype is not None and torch_dtype != torch.bfloat16:
                logger.warning_once(
                    f"Setting torch_dtype to {torch_dtype} for int4_weight_only quantization, but only bfloat16 is supported right now. Please set the torch_dtype to bfloat16."
                )
            if torch_dtype is None:
                logger.warning_once(
                    "Setting torch_dtype to torch.bfloat16 for int4_weight_only quantization since only bfloat16 is supported right now. Please set torch_dtype=torch.bfloat16 to remove this warning."
                )
                torch_dtype = torch.bfloat16
        return torch_dtype

    def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
        """
        Return the dtype associated with the configured `quant_type`; the incoming `target_dtype` is not used.

        Raises:
            ValueError: if the installed `accelerate` version is not greater than 0.19.0.
        """
        # `import importlib` alone does not guarantee that the `importlib.metadata` submodule is loaded, so import it
        # explicitly instead of relying on some other module having done so.
        import importlib.metadata

        if version.parse(importlib.metadata.version("accelerate")) > version.parse("0.19.0"):
            from accelerate.utils import CustomDtype

            map_to_target_dtype = {
                "int4_weight_only": CustomDtype.INT4,
                "int8_weight_only": torch.int8,
                "int8_dynamic_activation_int8_weight": torch.int8,
            }
            return map_to_target_dtype[self.quantization_config.quant_type]
        else:
            raise ValueError(
                "You are using `device_map='auto'` on a torchao quantized model. To automatically compute"
                " the appropriate device map, you should upgrade your `accelerate` library with "
                "`pip install --upgrade accelerate`"
            )

    def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]:
        """Return a copy of `max_memory` with every value scaled by 0.9."""
        # need more space for the quantization parameters (e.g. scale). Tested with int4 wo and group size = 128
        max_memory = {key: val * 0.9 for key, val in max_memory.items()}
        return max_memory

    def _process_model_before_weight_loading(self, model: "PreTrainedModel", **kwargs):
        """
        Build `self.modules_to_not_convert` from `get_keys_to_not_convert(model)` plus the entries listed in the
        quantization config, if any.
        """
        from ..integrations import get_keys_to_not_convert

        self.modules_to_not_convert = get_keys_to_not_convert(model)

        if self.quantization_config.modules_to_not_convert is not None:
            self.modules_to_not_convert.extend(self.quantization_config.modules_to_not_convert)

        return

    def check_quantized_param(
        self,
        model: "PreTrainedModel",
        param_value: "torch.Tensor",
        param_name: str,
        state_dict: Dict[str, Any],
        **kwargs,
    ) -> bool:
        """
        Tell whether the parameter called `param_name` has to go through `create_quantized_param`.

        Returns `False` for parameters matched by `self.modules_to_not_convert` and for parameters placed on
        `"cpu"` while offloading; otherwise returns `True` only for the `weight` of a `torch.nn.Linear`.
        `param_value` and `state_dict` are not used here.
        """
        param_device = kwargs.pop("param_device", None)
        # check if the param_name is not in self.modules_to_not_convert
        if any((key + "." in param_name) or (key == param_name) for key in self.modules_to_not_convert):
            return False
        elif param_device == "cpu" and self.offload:
            # We don't quantize weights that we offload
            return False
        else:
            # we only quantize the weight of nn.Linear
            module, tensor_name = get_module_from_name(model, param_name)
            return isinstance(module, torch.nn.Linear) and (tensor_name == "weight")

    def create_quantized_param(
        self,
        model: "PreTrainedModel",
        param_value: "torch.Tensor",
        param_name: str,
        target_device: "torch.device",
        state_dict: Dict[str, Any],
        unexpected_keys: List[str],
    ):
        """
        Each nn.Linear layer that needs to be quantized is processed here.
        First, we set the value of the weight tensor, then we move it to the target device. Finally, we quantize the
        module. `state_dict` and `unexpected_keys` are not used here.
        """
        module, tensor_name = get_module_from_name(model, param_name)
        module._parameters[tensor_name] = torch.nn.Parameter(param_value).to(device=target_device)
        quantize_(module, self.quantization_config.get_apply_tensor_subclass())

    def _process_model_after_weight_loading(self, model):
        """No process required for torchao quantized model"""
        return

    @property
    def is_serializable(self):
        """Always `False`."""
        return False

    @property
    def is_trainable(self):
        """Always `False`."""
        # torchao does not have official support for QAT (Quantization Aware Training)
        # but torchao support nf4/PEFT, but it is not integrated yet
        # TODO: if this is supported in the future, do a version check here.
        return False

View File

@@ -127,6 +127,7 @@ from .utils import (
is_torch_tf32_available, is_torch_tf32_available,
is_torch_xla_available, is_torch_xla_available,
is_torch_xpu_available, is_torch_xpu_available,
is_torchao_available,
is_torchaudio_available, is_torchaudio_available,
is_torchdynamo_available, is_torchdynamo_available,
is_torchvision_available, is_torchvision_available,
@@ -910,6 +911,11 @@ def require_torchdynamo(test_case):
return unittest.skipUnless(is_torchdynamo_available(), "test requires TorchDynamo")(test_case) return unittest.skipUnless(is_torchdynamo_available(), "test requires TorchDynamo")(test_case)
def require_torchao(test_case):
    """Decorator that skips ``test_case`` unless torchao is available."""
    skip_unless_torchao = unittest.skipUnless(is_torchao_available(), "test requires torchao")
    return skip_unless_torchao(test_case)
def require_torch_tensorrt_fx(test_case): def require_torch_tensorrt_fx(test_case):
"""Decorator marking a test that requires Torch-TensorRT FX""" """Decorator marking a test that requires Torch-TensorRT FX"""
return unittest.skipUnless(is_torch_tensorrt_fx_available(), "test requires Torch-TensorRT FX")(test_case) return unittest.skipUnless(is_torch_tensorrt_fx_available(), "test requires Torch-TensorRT FX")(test_case)

View File

@@ -210,6 +210,7 @@ from .import_utils import (
is_torch_tpu_available, is_torch_tpu_available,
is_torch_xla_available, is_torch_xla_available,
is_torch_xpu_available, is_torch_xpu_available,
is_torchao_available,
is_torchaudio_available, is_torchaudio_available,
is_torchdistx_available, is_torchdistx_available,
is_torchdynamo_available, is_torchdynamo_available,

View File

@@ -172,6 +172,7 @@ _tf2onnx_available = _is_package_available("tf2onnx")
_timm_available = _is_package_available("timm") _timm_available = _is_package_available("timm")
_tokenizers_available = _is_package_available("tokenizers") _tokenizers_available = _is_package_available("tokenizers")
_torchaudio_available = _is_package_available("torchaudio") _torchaudio_available = _is_package_available("torchaudio")
_torchao_available = _is_package_available("torchao")
_torchdistx_available = _is_package_available("torchdistx") _torchdistx_available = _is_package_available("torchdistx")
_torchvision_available = _is_package_available("torchvision") _torchvision_available = _is_package_available("torchvision")
_mlx_available = _is_package_available("mlx") _mlx_available = _is_package_available("mlx")
@@ -1092,6 +1093,10 @@ def is_torchaudio_available():
return _torchaudio_available return _torchaudio_available
def is_torchao_available():
    """Return the module-level ``_torchao_available`` flag (the result of ``_is_package_available("torchao")``)."""
    return _torchao_available
def is_speech_available(): def is_speech_available():
# For now this depends on torchaudio but the exact dependency might evolve in the future. # For now this depends on torchaudio but the exact dependency might evolve in the future.
return _torchaudio_available return _torchaudio_available

View File

@@ -20,17 +20,17 @@ import json
import os import os
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum from enum import Enum
from inspect import Parameter, signature
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional, Union
from packaging import version from packaging import version
from ..utils import is_auto_awq_available, is_hqq_available, is_torch_available, logging from ..utils import is_auto_awq_available, is_hqq_available, is_torch_available, is_torchao_available, logging
if is_torch_available(): if is_torch_available():
import torch import torch
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
@@ -43,6 +43,7 @@ class QuantizationMethod(str, Enum):
EETQ = "eetq" EETQ = "eetq"
HQQ = "hqq" HQQ = "hqq"
FBGEMM_FP8 = "fbgemm_fp8" FBGEMM_FP8 = "fbgemm_fp8"
TORCHAO = "torchao"
class AWQLinearVersion(str, Enum): class AWQLinearVersion(str, Enum):
@@ -1079,3 +1080,84 @@ class FbgemmFp8Config(QuantizationConfigMixin):
loading_attibutes = ["activation_scale_ub"] loading_attibutes = ["activation_scale_ub"]
loading_attibutes_dict = {i: j for i, j in attibutes_dict.items() if i in loading_attibutes} loading_attibutes_dict = {i: j for i, j in attibutes_dict.items() if i in loading_attibutes}
return loading_attibutes_dict return loading_attibutes_dict
@dataclass
class TorchAoConfig(QuantizationConfigMixin):
    """Configuration class for the quantization/sparsity techniques provided by torchao.

    Args:
        quant_type (`str`):
            Name of the quantization to apply. Currently supported: `int4_weight_only`, `int8_weight_only` and
            `int8_dynamic_activation_int8_weight`.
        modules_to_not_convert (`list`, *optional*, default to `None`):
            Modules that must not be quantized, useful for models that explicitly require some modules to stay in
            their original precision.
        kwargs (`Dict[str, Any]`, *optional*):
            Keyword arguments forwarded to the chosen quantization function. For example, `int4_weight_only`
            currently accepts `group_size` and `inner_k_tiles`. More API examples and documentation of the arguments
            can be found in
            https://github.com/pytorch/ao/tree/main/torchao/quantization#other-available-quantization-techniques

    Example:

    ```python
    quantization_config = TorchAoConfig("int4_weight_only", group_size=32)
    # int4_weight_only quant is only working with *torch.bfloat16* dtype right now
    model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cuda", torch_dtype=torch.bfloat16, quantization_config=quantization_config)
    ```
    """

    def __init__(self, quant_type: str, modules_to_not_convert: Optional[List] = None, **kwargs):
        self.quant_method = QuantizationMethod.TORCHAO
        self.quant_type = quant_type
        self.modules_to_not_convert = modules_to_not_convert
        self.kwargs = kwargs
        self._STR_TO_METHOD = {}

        if not is_torchao_available():
            raise ValueError(
                "TorchAoConfig requires torchao to be installed, please install with `pip install torchao`"
            )

        # Imported lazily so that this module can still be imported when torchao is missing.
        from torchao.quantization import (
            int4_weight_only,
            int8_dynamic_activation_int8_weight,
            int8_weight_only,
        )

        self._STR_TO_METHOD = {
            "int4_weight_only": int4_weight_only,
            "int8_weight_only": int8_weight_only,
            "int8_dynamic_activation_int8_weight": int8_dynamic_activation_int8_weight,
        }

        self.post_init()

    def post_init(self):
        r"""
        Validate the installed torchao version, the requested `quant_type` and the names of the extra keyword
        arguments against the signature of the selected torchao function.
        """
        installed_torchao = version.parse(importlib.metadata.version("torchao"))
        if installed_torchao < version.parse("0.4.0"):
            raise ValueError("Requires torchao 0.4.0 version and above")

        if self.quant_type not in self._STR_TO_METHOD:
            raise ValueError(
                f"Requested quantization type: {self.quant_type} is not supported yet, please add support in TorchAoConfig and TorchAoHfQuantizer."
            )

        method = self._STR_TO_METHOD[self.quant_type]
        # Only parameters that can be passed by keyword are accepted as extra kwargs.
        keyword_kinds = (Parameter.KEYWORD_ONLY, Parameter.POSITIONAL_OR_KEYWORD)
        all_kwargs = [param.name for param in signature(method).parameters.values() if param.kind in keyword_kinds]
        for k in self.kwargs:
            if k in all_kwargs:
                continue
            raise ValueError(
                f"Unexpected keyword arg: {k} for API: {method}, accepted keyword args are: {all_kwargs}"
            )

    def get_apply_tensor_subclass(self):
        """Call the torchao function selected by `quant_type` with the stored keyword arguments."""
        quantization_fn = self._STR_TO_METHOD[self.quant_type]
        return quantization_fn(**self.kwargs)

    def __repr__(self):
        rendered_kwargs = ", ".join(str(k) + "=" + str(v) for k, v in self.kwargs.items())
        return f"{self.quant_type}({rendered_kwargs})"

View File

@@ -0,0 +1,213 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import unittest
from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
from transformers.testing_utils import (
require_torch_gpu,
require_torch_multi_gpu,
require_torchao,
torch_device,
)
from transformers.utils import is_torch_available, is_torchao_available
if is_torch_available():
import torch
if is_torchao_available():
from torchao.dtypes import AffineQuantizedTensor
from torchao.dtypes.affine_quantized_tensor import TensorCoreTiledLayoutType
def check_torchao_quantized(test_module, qlayer, batch_size=1, context_size=1024):
    """Assert that ``qlayer.weight`` is an ``AffineQuantizedTensor`` with range 0..15 and a tensor-core tiled layout.

    ``batch_size`` and ``context_size`` are accepted but not used.
    """
    quantized_weight = qlayer.weight
    test_module.assertTrue(isinstance(quantized_weight, AffineQuantizedTensor))
    # 0..15 is the range of an unsigned 4-bit value.
    for attribute, expected in (("quant_min", 0), ("quant_max", 15)):
        test_module.assertEqual(getattr(quantized_weight, attribute), expected)
    test_module.assertTrue(isinstance(quantized_weight.layout_type, TensorCoreTiledLayoutType))
def check_forward(test_module, model, batch_size=1, context_size=1024):
    """Run one gradient-free forward pass on all-zero token ids and check the first two logits dimensions."""
    with torch.no_grad():
        dummy_ids = torch.zeros([batch_size, context_size], device=model.device, dtype=torch.int32)
        logits = model(dummy_ids).logits
    # Dimension 0 must match the batch size, dimension 1 the context size.
    for axis, expected in enumerate((batch_size, context_size)):
        test_module.assertEqual(logits.shape[axis], expected)
@require_torch_gpu
@require_torchao
class TorchAoConfigTest(unittest.TestCase):
    def test_to_dict(self):
        """
        Every entry returned by `to_dict()` must equal the attribute of the same name on the config
        """
        quantization_config = TorchAoConfig("int4_weight_only")
        for key, serialized_value in quantization_config.to_dict().items():
            self.assertEqual(getattr(quantization_config, key), serialized_value)

    def test_post_init_check(self):
        """
        A valid config is accepted; an unknown quant type and an unknown keyword argument are both rejected
        """
        TorchAoConfig("int4_weight_only")

        # (positional args, keyword args, expected error message pattern)
        invalid_cases = (
            (("fp6",), {}, "is not supported yet"),
            (("int4_weight_only",), {"group_size1": 32}, "Unexpected keyword arg"),
        )
        for args, kwargs, message_pattern in invalid_cases:
            with self.assertRaisesRegex(ValueError, message_pattern):
                TorchAoConfig(*args, **kwargs)
@require_torch_gpu
@require_torchao
class TorchAoTest(unittest.TestCase):
    input_text = "What are we having for dinner?"
    max_new_tokens = 10

    EXPECTED_OUTPUT = "What are we having for dinner?\n- 1. What is the temperature outside"

    model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

    def tearDown(self):
        gc.collect()
        torch.cuda.empty_cache()
        gc.collect()

    def _load_int4wo_model(self, torch_dtype, device_map):
        """Load `model_name`, quantized on the fly with int4 weight-only quantization (group size 32)."""
        quant_config = TorchAoConfig("int4_weight_only", group_size=32)
        return AutoModelForCausalLM.from_pretrained(
            self.model_name,
            torch_dtype=torch_dtype,
            device_map=device_map,
            quantization_config=quant_config,
        )

    def _generate_text(self, quantized_model, tokenizer):
        """Generate `max_new_tokens` tokens from `input_text` and return the decoded string."""
        input_ids = tokenizer(self.input_text, return_tensors="pt").to(torch_device)
        output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
        return tokenizer.decode(output[0], skip_special_tokens=True)

    def test_int4wo_quant(self):
        """
        Int4 weight-only quantization of a small LLM loaded in bfloat16
        """
        quantized_model = self._load_int4wo_model(torch_dtype=torch.bfloat16, device_map=torch_device)
        tokenizer = AutoTokenizer.from_pretrained(self.model_name)

        check_torchao_quantized(self, quantized_model.model.layers[0].self_attn.v_proj)

        self.assertEqual(self._generate_text(quantized_model, tokenizer), self.EXPECTED_OUTPUT)

    def test_int4wo_quant_bfloat16_conversion(self):
        """
        Int4 weight-only quantization when no `torch_dtype` is given (`torch_dtype=None`)
        """
        quantized_model = self._load_int4wo_model(torch_dtype=None, device_map=torch_device)
        tokenizer = AutoTokenizer.from_pretrained(self.model_name)

        check_torchao_quantized(self, quantized_model.model.layers[0].self_attn.v_proj)

        self.assertEqual(self._generate_text(quantized_model, tokenizer), self.EXPECTED_OUTPUT)

    @require_torch_multi_gpu
    def test_int4wo_quant_multi_gpu(self):
        """
        Checks that the int4 weight-only quantized model works properly with multiple GPUs
        set CUDA_VISIBLE_DEVICES=0,1 if you have more than 2 GPUs
        """
        quantized_model = self._load_int4wo_model(torch_dtype=torch.bfloat16, device_map="auto")
        tokenizer = AutoTokenizer.from_pretrained(self.model_name)

        self.assertTrue(set(quantized_model.hf_device_map.values()) == {0, 1})

        self.assertEqual(self._generate_text(quantized_model, tokenizer), self.EXPECTED_OUTPUT)

    def test_int4wo_offload(self):
        """
        Checks that the int4 weight-only quantized model works properly with cpu/disk offload
        """
        # Layers 0-18 stay on GPU 0; layers 19-20 go to the cpu and layer 21 to disk.
        device_map_offload = {"model.embed_tokens": 0}
        device_map_offload.update({f"model.layers.{layer_idx}": 0 for layer_idx in range(19)})
        device_map_offload.update(
            {
                "model.layers.19": "cpu",
                "model.layers.20": "cpu",
                "model.layers.21": "disk",
                "model.norm": 0,
                "model.rotary_emb": 0,
                "lm_head": 0,
            }
        )

        quantized_model = self._load_int4wo_model(torch_dtype=torch.bfloat16, device_map=device_map_offload)
        tokenizer = AutoTokenizer.from_pretrained(self.model_name)

        EXPECTED_OUTPUT = "What are we having for dinner?\n- 2. What is the temperature outside"

        self.assertEqual(self._generate_text(quantized_model, tokenizer), EXPECTED_OUTPUT)
if __name__ == "__main__":
unittest.main()