Add TorchAOHfQuantizer (#32306)
* Add TorchAOHfQuantizer Summary: Enable loading torchao quantized model in huggingface. Test Plan: local test Reviewers: Subscribers: Tasks: Tags: * Fix a few issues * style * Added tests and addressed some comments about dtype conversion * fix torch_dtype warning message * fix tests * style * TorchAOConfig -> TorchAoConfig * enable offload + fix memory with multi-gpu * update torchao version requirement to 0.4.0 * better comments * add torch.compile to torchao README, add perf number link --------- Co-authored-by: Marc Sun <marc@huggingface.co>
This commit is contained in:
@@ -163,6 +163,8 @@
|
||||
title: FBGEMM_FP8
|
||||
- local: quantization/optimum
|
||||
title: Optimum
|
||||
- local: quantization/torchao
|
||||
title: TorchAO
|
||||
- local: quantization/contribute
|
||||
title: Contribute new quantization method
|
||||
title: Quantization Methods
|
||||
|
||||
@@ -61,3 +61,7 @@ Learn how to quantize models in the [Quantization](../quantization) guide.
|
||||
|
||||
[[autodoc]] FbgemmFp8Config
|
||||
|
||||
## TorchAoConfig
|
||||
|
||||
[[autodoc]] TorchAoConfig
|
||||
|
||||
|
||||
@@ -56,4 +56,4 @@ Use the table below to help you decide which quantization method to use.
|
||||
| [HQQ](./hqq) | 🟢 | 🟢 | 🟢 | 🔴 | 🔴 | 🟢 | 1 - 8 | 🟢 | 🔴 | 🟢 | https://github.com/mobiusml/hqq/ |
|
||||
| [Quanto](./quanto) | 🟢 | 🟢 | 🟢 | 🔴 | 🟢 | 🟢 | 2 / 4 / 8 | 🔴 | 🔴 | 🟢 | https://github.com/huggingface/quanto |
|
||||
| [FBGEMM_FP8](./fbgemm_fp8.md) | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | 8 | 🔴 | 🟢 | 🟢 | https://github.com/pytorch/FBGEMM |
|
||||
|
||||
| [torchao](./torchao.md) | 🟢 | | 🟢 | 🔴 | partial support (int4 weight only) | | 4 / 8 | | 🟢🔴 | 🟢 | https://github.com/pytorch/ao |
|
||||
|
||||
45
docs/source/en/quantization/torchao.md
Normal file
45
docs/source/en/quantization/torchao.md
Normal file
@@ -0,0 +1,45 @@
|
||||
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
-->
|
||||
|
||||
# TorchAO
|
||||
|
||||
[TorchAO](https://github.com/pytorch/ao) is an architecture optimization library for PyTorch, it provides high performance dtypes, optimization techniques and kernels for inference and training, featuring composability with native PyTorch features like `torch.compile`, FSDP etc.. Some benchmark numbers can be found [here](https://github.com/pytorch/ao/tree/main?tab=readme-ov-file#without-intrusive-code-changes)
|
||||
|
||||
Before you begin, make sure the following libraries are installed with their latest version:
|
||||
|
||||
```bash
|
||||
pip install --upgrade torch torchao
|
||||
```
|
||||
|
||||
|
||||
```py
|
||||
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
model_name = "meta-llama/Meta-Llama-3-8B"
|
||||
# We support int4_weight_only, int8_weight_only and int8_dynamic_activation_int8_weight
|
||||
# More examples and documentations for arguments can be found in https://github.com/pytorch/ao/tree/main/torchao/quantization#other-available-quantization-techniques
|
||||
quantization_config = TorchAoConfig("int4_weight_only", group_size=128)
|
||||
quantized_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", quantization_config=quantization_config)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
input_text = "What are we having for dinner?"
|
||||
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
|
||||
|
||||
# compile the quantizd model to get speedup
|
||||
import torchao
|
||||
torchao.quantization.utils.recommended_inductor_config_setter()
|
||||
quantized_model = torch.compile(quantized_model, mode="max-autotune")
|
||||
|
||||
output = quantized_model.generate(**input_ids, max_new_tokens=10)
|
||||
print(tokenizer.decode(output[0], skip_special_tokens=True))
|
||||
```
|
||||
|
||||
torchao quantization is implemented with tensor subclasses, currently it does not work with huggingface serialization, both the safetensor option and [non-safetensor option](https://github.com/huggingface/transformers/issues/32364), we'll update here with instructions when it's working.
|
||||
@@ -949,6 +949,7 @@ _import_structure = {
|
||||
"GPTQConfig",
|
||||
"HqqConfig",
|
||||
"QuantoConfig",
|
||||
"TorchAoConfig",
|
||||
],
|
||||
}
|
||||
|
||||
@@ -5728,6 +5729,7 @@ if TYPE_CHECKING:
|
||||
GPTQConfig,
|
||||
HqqConfig,
|
||||
QuantoConfig,
|
||||
TorchAoConfig,
|
||||
)
|
||||
|
||||
try:
|
||||
|
||||
@@ -26,6 +26,7 @@ from ..utils.quantization_config import (
|
||||
QuantizationConfigMixin,
|
||||
QuantizationMethod,
|
||||
QuantoConfig,
|
||||
TorchAoConfig,
|
||||
)
|
||||
from .quantizer_aqlm import AqlmHfQuantizer
|
||||
from .quantizer_awq import AwqQuantizer
|
||||
@@ -36,6 +37,7 @@ from .quantizer_fbgemm_fp8 import FbgemmFp8HfQuantizer
|
||||
from .quantizer_gptq import GptqHfQuantizer
|
||||
from .quantizer_hqq import HqqHfQuantizer
|
||||
from .quantizer_quanto import QuantoHfQuantizer
|
||||
from .quantizer_torchao import TorchAoHfQuantizer
|
||||
|
||||
|
||||
AUTO_QUANTIZER_MAPPING = {
|
||||
@@ -48,6 +50,7 @@ AUTO_QUANTIZER_MAPPING = {
|
||||
"eetq": EetqHfQuantizer,
|
||||
"hqq": HqqHfQuantizer,
|
||||
"fbgemm_fp8": FbgemmFp8HfQuantizer,
|
||||
"torchao": TorchAoHfQuantizer,
|
||||
}
|
||||
|
||||
AUTO_QUANTIZATION_CONFIG_MAPPING = {
|
||||
@@ -60,6 +63,7 @@ AUTO_QUANTIZATION_CONFIG_MAPPING = {
|
||||
"quanto": QuantoConfig,
|
||||
"hqq": HqqConfig,
|
||||
"fbgemm_fp8": FbgemmFp8Config,
|
||||
"torchao": TorchAoConfig,
|
||||
}
|
||||
|
||||
|
||||
|
||||
172
src/transformers/quantizers/quantizer_torchao.py
Normal file
172
src/transformers/quantizers/quantizer_torchao.py
Normal file
@@ -0,0 +1,172 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import importlib
|
||||
from typing import TYPE_CHECKING, Union
|
||||
|
||||
from packaging import version
|
||||
|
||||
from .base import HfQuantizer
|
||||
from .quantizers_utils import get_module_from_name
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..modeling_utils import PreTrainedModel
|
||||
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from ..utils import is_torch_available, is_torchao_available, logging
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
if is_torchao_available():
|
||||
from torchao.quantization import quantize_
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
# Finds the parent of a node module named "name"
|
||||
def find_parent(model, name):
|
||||
module_tree = name.split(".")[:-1]
|
||||
parent = model
|
||||
for m in module_tree:
|
||||
parent = parent._modules[m]
|
||||
return parent
|
||||
|
||||
|
||||
class TorchAoHfQuantizer(HfQuantizer):
|
||||
"""
|
||||
Quantizer for torchao: https://github.com/pytorch/ao/
|
||||
"""
|
||||
|
||||
requires_parameters_quantization = True
|
||||
requires_calibration = False
|
||||
required_packages = ["torchao"]
|
||||
|
||||
def __init__(self, quantization_config, **kwargs):
|
||||
super().__init__(quantization_config, **kwargs)
|
||||
|
||||
def validate_environment(self, *args, **kwargs):
|
||||
if not is_torchao_available():
|
||||
raise ImportError("Loading an torchao quantized model requires torchao library (`pip install torchao`)")
|
||||
|
||||
self.offload = False
|
||||
device_map = kwargs.get("device_map", None)
|
||||
if isinstance(device_map, dict):
|
||||
if "cpu" in device_map.values() or "disk" in device_map.values():
|
||||
if self.pre_quantized:
|
||||
raise ValueError(
|
||||
"You are attempting to perform cpu/disk offload with a pre-quantized torchao model "
|
||||
"This is not supported yet . Please remove the CPU or disk device from the device_map."
|
||||
)
|
||||
else:
|
||||
self.offload = True
|
||||
|
||||
def update_torch_dtype(self, torch_dtype):
|
||||
if self.quantization_config.quant_type == "int4_weight_only":
|
||||
if torch_dtype is not None and torch_dtype != torch.bfloat16:
|
||||
logger.warning_once(
|
||||
f"Setting torch_dtype to {torch_dtype} for int4_weight_only quantization, but only bfloat16 is supported right now. Please set the torch_dtype to bfloat16."
|
||||
)
|
||||
if torch_dtype is None:
|
||||
logger.warning_once(
|
||||
"Setting torch_dtype to torch.bfloat16 for int4_weight_only quantization since only bfloat16 is supported right now. Please set torch_dtype=torch.bfloat16 to remove this warning."
|
||||
)
|
||||
torch_dtype = torch.bfloat16
|
||||
return torch_dtype
|
||||
|
||||
def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
|
||||
if version.parse(importlib.metadata.version("accelerate")) > version.parse("0.19.0"):
|
||||
from accelerate.utils import CustomDtype
|
||||
|
||||
map_to_target_dtype = {
|
||||
"int4_weight_only": CustomDtype.INT4,
|
||||
"int8_weight_only": torch.int8,
|
||||
"int8_dynamic_activation_int8_weight": torch.int8,
|
||||
}
|
||||
return map_to_target_dtype[self.quantization_config.quant_type]
|
||||
else:
|
||||
raise ValueError(
|
||||
"You are using `device_map='auto'` on a torchao quantized model. To automatically compute"
|
||||
" the appropriate device map, you should upgrade your `accelerate` library with "
|
||||
"`pip install --upgrade accelerate`"
|
||||
)
|
||||
|
||||
def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]:
|
||||
# need more space for the quantization parameters (e.g. scale). Tested with int4 wo and group size = 128
|
||||
max_memory = {key: val * 0.9 for key, val in max_memory.items()}
|
||||
return max_memory
|
||||
|
||||
def _process_model_before_weight_loading(self, model: "PreTrainedModel", **kwargs):
|
||||
from ..integrations import get_keys_to_not_convert
|
||||
|
||||
self.modules_to_not_convert = get_keys_to_not_convert(model)
|
||||
|
||||
if self.quantization_config.modules_to_not_convert is not None:
|
||||
self.modules_to_not_convert.extend(self.quantization_config.modules_to_not_convert)
|
||||
|
||||
return
|
||||
|
||||
def check_quantized_param(
|
||||
self,
|
||||
model: "PreTrainedModel",
|
||||
param_value: "torch.Tensor",
|
||||
param_name: str,
|
||||
state_dict: Dict[str, Any],
|
||||
**kwargs,
|
||||
) -> bool:
|
||||
param_device = kwargs.pop("param_device", None)
|
||||
# check if the param_name is not in self.modules_to_not_convert
|
||||
if any((key + "." in param_name) or (key == param_name) for key in self.modules_to_not_convert):
|
||||
return False
|
||||
elif param_device == "cpu" and self.offload:
|
||||
# We don't quantize weights that we offload
|
||||
return False
|
||||
else:
|
||||
# we only quantize the weight of nn.Linear
|
||||
module, tensor_name = get_module_from_name(model, param_name)
|
||||
return isinstance(module, torch.nn.Linear) and (tensor_name == "weight")
|
||||
|
||||
def create_quantized_param(
|
||||
self,
|
||||
model: "PreTrainedModel",
|
||||
param_value: "torch.Tensor",
|
||||
param_name: str,
|
||||
target_device: "torch.device",
|
||||
state_dict: Dict[str, Any],
|
||||
unexpected_keys: List[str],
|
||||
):
|
||||
"""
|
||||
Each nn.Linear layer that needs to be quantized is processsed here.
|
||||
First, we set the value the weight tensor, then we move it to the target device. Finally, we quantize the module.
|
||||
"""
|
||||
module, tensor_name = get_module_from_name(model, param_name)
|
||||
module._parameters[tensor_name] = torch.nn.Parameter(param_value).to(device=target_device)
|
||||
quantize_(module, self.quantization_config.get_apply_tensor_subclass())
|
||||
|
||||
def _process_model_after_weight_loading(self, model):
|
||||
"""No process required for torchao quantized model"""
|
||||
return
|
||||
|
||||
@property
|
||||
def is_serializable(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def is_trainable(self):
|
||||
# torchao does not have official support for QAT (Quantization Aware Training)
|
||||
# but torchao support nf4/PEFT, but it is not integrated yet
|
||||
# TODO: if this is supported in the future, do a version check here.
|
||||
return False
|
||||
@@ -127,6 +127,7 @@ from .utils import (
|
||||
is_torch_tf32_available,
|
||||
is_torch_xla_available,
|
||||
is_torch_xpu_available,
|
||||
is_torchao_available,
|
||||
is_torchaudio_available,
|
||||
is_torchdynamo_available,
|
||||
is_torchvision_available,
|
||||
@@ -910,6 +911,11 @@ def require_torchdynamo(test_case):
|
||||
return unittest.skipUnless(is_torchdynamo_available(), "test requires TorchDynamo")(test_case)
|
||||
|
||||
|
||||
def require_torchao(test_case):
|
||||
"""Decorator marking a test that requires torchao"""
|
||||
return unittest.skipUnless(is_torchao_available(), "test requires torchao")(test_case)
|
||||
|
||||
|
||||
def require_torch_tensorrt_fx(test_case):
|
||||
"""Decorator marking a test that requires Torch-TensorRT FX"""
|
||||
return unittest.skipUnless(is_torch_tensorrt_fx_available(), "test requires Torch-TensorRT FX")(test_case)
|
||||
|
||||
@@ -210,6 +210,7 @@ from .import_utils import (
|
||||
is_torch_tpu_available,
|
||||
is_torch_xla_available,
|
||||
is_torch_xpu_available,
|
||||
is_torchao_available,
|
||||
is_torchaudio_available,
|
||||
is_torchdistx_available,
|
||||
is_torchdynamo_available,
|
||||
|
||||
@@ -172,6 +172,7 @@ _tf2onnx_available = _is_package_available("tf2onnx")
|
||||
_timm_available = _is_package_available("timm")
|
||||
_tokenizers_available = _is_package_available("tokenizers")
|
||||
_torchaudio_available = _is_package_available("torchaudio")
|
||||
_torchao_available = _is_package_available("torchao")
|
||||
_torchdistx_available = _is_package_available("torchdistx")
|
||||
_torchvision_available = _is_package_available("torchvision")
|
||||
_mlx_available = _is_package_available("mlx")
|
||||
@@ -1092,6 +1093,10 @@ def is_torchaudio_available():
|
||||
return _torchaudio_available
|
||||
|
||||
|
||||
def is_torchao_available():
|
||||
return _torchao_available
|
||||
|
||||
|
||||
def is_speech_available():
|
||||
# For now this depends on torchaudio but the exact dependency might evolve in the future.
|
||||
return _torchaudio_available
|
||||
|
||||
@@ -20,17 +20,17 @@ import json
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from inspect import Parameter, signature
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from packaging import version
|
||||
|
||||
from ..utils import is_auto_awq_available, is_hqq_available, is_torch_available, logging
|
||||
from ..utils import is_auto_awq_available, is_hqq_available, is_torch_available, is_torchao_available, logging
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@@ -43,6 +43,7 @@ class QuantizationMethod(str, Enum):
|
||||
EETQ = "eetq"
|
||||
HQQ = "hqq"
|
||||
FBGEMM_FP8 = "fbgemm_fp8"
|
||||
TORCHAO = "torchao"
|
||||
|
||||
|
||||
class AWQLinearVersion(str, Enum):
|
||||
@@ -1079,3 +1080,84 @@ class FbgemmFp8Config(QuantizationConfigMixin):
|
||||
loading_attibutes = ["activation_scale_ub"]
|
||||
loading_attibutes_dict = {i: j for i, j in attibutes_dict.items() if i in loading_attibutes}
|
||||
return loading_attibutes_dict
|
||||
|
||||
|
||||
@dataclass
|
||||
class TorchAoConfig(QuantizationConfigMixin):
|
||||
"""This is a config class for torchao quantization/sparsity techniques.
|
||||
|
||||
Args:
|
||||
quant_type (`str`):
|
||||
The type of quantization we want to use, currently supporting: `int4_weight_only`, `int8_weight_only` and `int8_dynamic_activation_int8_weight`.
|
||||
modules_to_not_convert (`list`, *optional*, default to `None`):
|
||||
The list of modules to not quantize, useful for quantizing models that explicitly require to have
|
||||
some modules left in their original precision.
|
||||
kwargs (`Dict[str, Any]`, *optional*):
|
||||
The keyword arguments for the chosen type of quantization, for example, int4_weight_only quantization supports two keyword arguments
|
||||
`group_size` and `inner_k_tiles` currently. More API examples and documentation of arguments can be found in
|
||||
https://github.com/pytorch/ao/tree/main/torchao/quantization#other-available-quantization-techniques
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
quantization_config = TorchAoConfig("int4_weight_only", group_size=32)
|
||||
# int4_weight_only quant is only working with *torch.bfloat16* dtype right now
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cuda", torch_dtype=torch.bfloat16, quantization_config=quantization_config)
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, quant_type: str, modules_to_not_convert: Optional[List] = None, **kwargs):
|
||||
self.quant_method = QuantizationMethod.TORCHAO
|
||||
self.quant_type = quant_type
|
||||
self.modules_to_not_convert = modules_to_not_convert
|
||||
self.kwargs = kwargs
|
||||
self._STR_TO_METHOD = {}
|
||||
if is_torchao_available():
|
||||
from torchao.quantization import (
|
||||
int4_weight_only,
|
||||
int8_dynamic_activation_int8_weight,
|
||||
int8_weight_only,
|
||||
)
|
||||
|
||||
self._STR_TO_METHOD = {
|
||||
"int4_weight_only": int4_weight_only,
|
||||
"int8_weight_only": int8_weight_only,
|
||||
"int8_dynamic_activation_int8_weight": int8_dynamic_activation_int8_weight,
|
||||
}
|
||||
else:
|
||||
raise ValueError(
|
||||
"TorchAoConfig requires torchao to be installed, please install with `pip install torchao`"
|
||||
)
|
||||
|
||||
self.post_init()
|
||||
|
||||
def post_init(self):
|
||||
r"""
|
||||
Safety checker that arguments are correct - also replaces some NoneType arguments with their default values.
|
||||
"""
|
||||
if not version.parse(importlib.metadata.version("torchao")) >= version.parse("0.4.0"):
|
||||
raise ValueError("Requires torchao 0.4.0 version and above")
|
||||
|
||||
if self.quant_type not in self._STR_TO_METHOD.keys():
|
||||
raise ValueError(
|
||||
f"Requested quantization type: {self.quant_type} is not supported yet, please add support in TorchAoConfig and TorchAoHfQuantizer."
|
||||
)
|
||||
|
||||
method = self._STR_TO_METHOD[self.quant_type]
|
||||
sig = signature(method)
|
||||
all_kwargs = [
|
||||
param.name
|
||||
for param in sig.parameters.values()
|
||||
if param.kind in [Parameter.KEYWORD_ONLY, Parameter.POSITIONAL_OR_KEYWORD]
|
||||
]
|
||||
for k in self.kwargs:
|
||||
if k not in all_kwargs:
|
||||
raise ValueError(
|
||||
f"Unexpected keyword arg: {k} for API: {method}, accepted keyword args are: {all_kwargs}"
|
||||
)
|
||||
|
||||
def get_apply_tensor_subclass(self):
|
||||
return self._STR_TO_METHOD[self.quant_type](**self.kwargs)
|
||||
|
||||
def __repr__(self):
|
||||
return f"{self.quant_type}({', '.join(str(k) + '=' + str(v) for k, v in self.kwargs.items())})"
|
||||
|
||||
0
tests/quantization/torchao_integration/__init__.py
Normal file
0
tests/quantization/torchao_integration/__init__.py
Normal file
213
tests/quantization/torchao_integration/test_torchao.py
Normal file
213
tests/quantization/torchao_integration/test_torchao.py
Normal file
@@ -0,0 +1,213 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import unittest
|
||||
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
|
||||
from transformers.testing_utils import (
|
||||
require_torch_gpu,
|
||||
require_torch_multi_gpu,
|
||||
require_torchao,
|
||||
torch_device,
|
||||
)
|
||||
from transformers.utils import is_torch_available, is_torchao_available
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
if is_torchao_available():
|
||||
from torchao.dtypes import AffineQuantizedTensor
|
||||
from torchao.dtypes.affine_quantized_tensor import TensorCoreTiledLayoutType
|
||||
|
||||
|
||||
def check_torchao_quantized(test_module, qlayer, batch_size=1, context_size=1024):
|
||||
weight = qlayer.weight
|
||||
test_module.assertTrue(isinstance(weight, AffineQuantizedTensor))
|
||||
test_module.assertEqual(weight.quant_min, 0)
|
||||
test_module.assertEqual(weight.quant_max, 15)
|
||||
test_module.assertTrue(isinstance(weight.layout_type, TensorCoreTiledLayoutType))
|
||||
|
||||
|
||||
def check_forward(test_module, model, batch_size=1, context_size=1024):
|
||||
# Test forward pass
|
||||
with torch.no_grad():
|
||||
out = model(torch.zeros([batch_size, context_size], device=model.device, dtype=torch.int32)).logits
|
||||
test_module.assertEqual(out.shape[0], batch_size)
|
||||
test_module.assertEqual(out.shape[1], context_size)
|
||||
|
||||
|
||||
@require_torch_gpu
|
||||
@require_torchao
|
||||
class TorchAoConfigTest(unittest.TestCase):
|
||||
def test_to_dict(self):
|
||||
"""
|
||||
Makes sure the config format is properly set
|
||||
"""
|
||||
quantization_config = TorchAoConfig("int4_weight_only")
|
||||
torchao_orig_config = quantization_config.to_dict()
|
||||
|
||||
for key in torchao_orig_config:
|
||||
self.assertEqual(getattr(quantization_config, key), torchao_orig_config[key])
|
||||
|
||||
def test_post_init_check(self):
|
||||
"""
|
||||
Test kwargs validations in TorchAoConfig
|
||||
"""
|
||||
_ = TorchAoConfig("int4_weight_only")
|
||||
with self.assertRaisesRegex(ValueError, "is not supported yet"):
|
||||
_ = TorchAoConfig("fp6")
|
||||
|
||||
with self.assertRaisesRegex(ValueError, "Unexpected keyword arg"):
|
||||
_ = TorchAoConfig("int4_weight_only", group_size1=32)
|
||||
|
||||
|
||||
@require_torch_gpu
|
||||
@require_torchao
|
||||
class TorchAoTest(unittest.TestCase):
|
||||
input_text = "What are we having for dinner?"
|
||||
max_new_tokens = 10
|
||||
|
||||
EXPECTED_OUTPUT = "What are we having for dinner?\n- 1. What is the temperature outside"
|
||||
|
||||
model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
|
||||
|
||||
def tearDown(self):
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
def test_int4wo_quant(self):
|
||||
"""
|
||||
Simple LLM model testing int4 weight only quantization
|
||||
"""
|
||||
quant_config = TorchAoConfig("int4_weight_only", group_size=32)
|
||||
|
||||
# Note: we quantize the bfloat16 model on the fly to int4
|
||||
quantized_model = AutoModelForCausalLM.from_pretrained(
|
||||
self.model_name,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map=torch_device,
|
||||
quantization_config=quant_config,
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(self.model_name)
|
||||
|
||||
check_torchao_quantized(self, quantized_model.model.layers[0].self_attn.v_proj)
|
||||
|
||||
input_ids = tokenizer(self.input_text, return_tensors="pt").to(torch_device)
|
||||
|
||||
output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
|
||||
self.assertEqual(tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
|
||||
|
||||
def test_int4wo_quant_bfloat16_conversion(self):
|
||||
"""
|
||||
Testing the dtype of model will be modified to be bfloat16 for int4 weight only quantization
|
||||
"""
|
||||
quant_config = TorchAoConfig("int4_weight_only", group_size=32)
|
||||
|
||||
# Note: we quantize the bfloat16 model on the fly to int4
|
||||
quantized_model = AutoModelForCausalLM.from_pretrained(
|
||||
self.model_name,
|
||||
torch_dtype=None,
|
||||
device_map=torch_device,
|
||||
quantization_config=quant_config,
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(self.model_name)
|
||||
|
||||
check_torchao_quantized(self, quantized_model.model.layers[0].self_attn.v_proj)
|
||||
|
||||
input_ids = tokenizer(self.input_text, return_tensors="pt").to(torch_device)
|
||||
|
||||
output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
|
||||
self.assertEqual(tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
|
||||
|
||||
@require_torch_multi_gpu
|
||||
def test_int4wo_quant_multi_gpu(self):
|
||||
"""
|
||||
Simple test that checks if the quantized model int4 wieght only is working properly with multiple GPUs
|
||||
set CUDA_VISIBLE_DEVICES=0,1 if you have more than 2 GPUS
|
||||
"""
|
||||
|
||||
quant_config = TorchAoConfig("int4_weight_only", group_size=32)
|
||||
quantized_model = AutoModelForCausalLM.from_pretrained(
|
||||
self.model_name,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="auto",
|
||||
quantization_config=quant_config,
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(self.model_name)
|
||||
|
||||
self.assertTrue(set(quantized_model.hf_device_map.values()) == {0, 1})
|
||||
|
||||
input_ids = tokenizer(self.input_text, return_tensors="pt").to(torch_device)
|
||||
|
||||
output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
|
||||
self.assertEqual(tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
|
||||
|
||||
def test_int4wo_offload(self):
|
||||
"""
|
||||
Simple test that checks if the quantized model int4 wieght only is working properly with cpu/disk offload
|
||||
"""
|
||||
|
||||
device_map_offload = {
|
||||
"model.embed_tokens": 0,
|
||||
"model.layers.0": 0,
|
||||
"model.layers.1": 0,
|
||||
"model.layers.2": 0,
|
||||
"model.layers.3": 0,
|
||||
"model.layers.4": 0,
|
||||
"model.layers.5": 0,
|
||||
"model.layers.6": 0,
|
||||
"model.layers.7": 0,
|
||||
"model.layers.8": 0,
|
||||
"model.layers.9": 0,
|
||||
"model.layers.10": 0,
|
||||
"model.layers.11": 0,
|
||||
"model.layers.12": 0,
|
||||
"model.layers.13": 0,
|
||||
"model.layers.14": 0,
|
||||
"model.layers.15": 0,
|
||||
"model.layers.16": 0,
|
||||
"model.layers.17": 0,
|
||||
"model.layers.18": 0,
|
||||
"model.layers.19": "cpu",
|
||||
"model.layers.20": "cpu",
|
||||
"model.layers.21": "disk",
|
||||
"model.norm": 0,
|
||||
"model.rotary_emb": 0,
|
||||
"lm_head": 0,
|
||||
}
|
||||
|
||||
quant_config = TorchAoConfig("int4_weight_only", group_size=32)
|
||||
|
||||
quantized_model = AutoModelForCausalLM.from_pretrained(
|
||||
self.model_name,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map=device_map_offload,
|
||||
quantization_config=quant_config,
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(self.model_name)
|
||||
|
||||
input_ids = tokenizer(self.input_text, return_tensors="pt").to(torch_device)
|
||||
|
||||
output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
|
||||
EXPECTED_OUTPUT = "What are we having for dinner?\n- 2. What is the temperature outside"
|
||||
|
||||
self.assertEqual(tokenizer.decode(output[0], skip_special_tokens=True), EXPECTED_OUTPUT)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user