Add TorchAOHfQuantizer (#32306)
* Add TorchAOHfQuantizer Summary: Enable loading torchao quantized model in huggingface. Test Plan: local test Reviewers: Subscribers: Tasks: Tags: * Fix a few issues * style * Added tests and addressed some comments about dtype conversion * fix torch_dtype warning message * fix tests * style * TorchAOConfig -> TorchAoConfig * enable offload + fix memory with multi-gpu * update torchao version requirement to 0.4.0 * better comments * add torch.compile to torchao README, add perf number link --------- Co-authored-by: Marc Sun <marc@huggingface.co>
This commit is contained in:
@@ -163,6 +163,8 @@
|
|||||||
title: FBGEMM_FP8
|
title: FBGEMM_FP8
|
||||||
- local: quantization/optimum
|
- local: quantization/optimum
|
||||||
title: Optimum
|
title: Optimum
|
||||||
|
- local: quantization/torchao
|
||||||
|
title: TorchAO
|
||||||
- local: quantization/contribute
|
- local: quantization/contribute
|
||||||
title: Contribute new quantization method
|
title: Contribute new quantization method
|
||||||
title: Quantization Methods
|
title: Quantization Methods
|
||||||
|
|||||||
@@ -61,3 +61,7 @@ Learn how to quantize models in the [Quantization](../quantization) guide.
|
|||||||
|
|
||||||
[[autodoc]] FbgemmFp8Config
|
[[autodoc]] FbgemmFp8Config
|
||||||
|
|
||||||
|
## TorchAoConfig
|
||||||
|
|
||||||
|
[[autodoc]] TorchAoConfig
|
||||||
|
|
||||||
|
|||||||
@@ -56,4 +56,4 @@ Use the table below to help you decide which quantization method to use.
|
|||||||
| [HQQ](./hqq) | 🟢 | 🟢 | 🟢 | 🔴 | 🔴 | 🟢 | 1 - 8 | 🟢 | 🔴 | 🟢 | https://github.com/mobiusml/hqq/ |
|
| [HQQ](./hqq) | 🟢 | 🟢 | 🟢 | 🔴 | 🔴 | 🟢 | 1 - 8 | 🟢 | 🔴 | 🟢 | https://github.com/mobiusml/hqq/ |
|
||||||
| [Quanto](./quanto) | 🟢 | 🟢 | 🟢 | 🔴 | 🟢 | 🟢 | 2 / 4 / 8 | 🔴 | 🔴 | 🟢 | https://github.com/huggingface/quanto |
|
| [Quanto](./quanto) | 🟢 | 🟢 | 🟢 | 🔴 | 🟢 | 🟢 | 2 / 4 / 8 | 🔴 | 🔴 | 🟢 | https://github.com/huggingface/quanto |
|
||||||
| [FBGEMM_FP8](./fbgemm_fp8.md) | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | 8 | 🔴 | 🟢 | 🟢 | https://github.com/pytorch/FBGEMM |
|
| [FBGEMM_FP8](./fbgemm_fp8.md) | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | 8 | 🔴 | 🟢 | 🟢 | https://github.com/pytorch/FBGEMM |
|
||||||
|
| [torchao](./torchao.md) | 🟢 | | 🟢 | 🔴 | partial support (int4 weight only) | | 4 / 8 | | 🟢🔴 | 🟢 | https://github.com/pytorch/ao |
|
||||||
|
|||||||
45
docs/source/en/quantization/torchao.md
Normal file
45
docs/source/en/quantization/torchao.md
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||||
|
the License. You may obtain a copy of the License at
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||||
|
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||||
|
specific language governing permissions and limitations under the License.
|
||||||
|
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||||
|
rendered properly in your Markdown viewer.
|
||||||
|
-->
|
||||||
|
|
||||||
|
# TorchAO
|
||||||
|
|
||||||
|
[TorchAO](https://github.com/pytorch/ao) is an architecture optimization library for PyTorch, it provides high performance dtypes, optimization techniques and kernels for inference and training, featuring composability with native PyTorch features like `torch.compile`, FSDP etc.. Some benchmark numbers can be found [here](https://github.com/pytorch/ao/tree/main?tab=readme-ov-file#without-intrusive-code-changes)
|
||||||
|
|
||||||
|
Before you begin, make sure the following libraries are installed with their latest version:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install --upgrade torch torchao
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
```py
|
||||||
|
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
|
||||||
|
|
||||||
|
model_name = "meta-llama/Meta-Llama-3-8B"
|
||||||
|
# We support int4_weight_only, int8_weight_only and int8_dynamic_activation_int8_weight
|
||||||
|
# More examples and documentations for arguments can be found in https://github.com/pytorch/ao/tree/main/torchao/quantization#other-available-quantization-techniques
|
||||||
|
quantization_config = TorchAoConfig("int4_weight_only", group_size=128)
|
||||||
|
quantized_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", quantization_config=quantization_config)
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||||
|
input_text = "What are we having for dinner?"
|
||||||
|
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
|
||||||
|
|
||||||
|
# compile the quantizd model to get speedup
|
||||||
|
import torchao
|
||||||
|
torchao.quantization.utils.recommended_inductor_config_setter()
|
||||||
|
quantized_model = torch.compile(quantized_model, mode="max-autotune")
|
||||||
|
|
||||||
|
output = quantized_model.generate(**input_ids, max_new_tokens=10)
|
||||||
|
print(tokenizer.decode(output[0], skip_special_tokens=True))
|
||||||
|
```
|
||||||
|
|
||||||
|
torchao quantization is implemented with tensor subclasses, currently it does not work with huggingface serialization, both the safetensor option and [non-safetensor option](https://github.com/huggingface/transformers/issues/32364), we'll update here with instructions when it's working.
|
||||||
@@ -949,6 +949,7 @@ _import_structure = {
|
|||||||
"GPTQConfig",
|
"GPTQConfig",
|
||||||
"HqqConfig",
|
"HqqConfig",
|
||||||
"QuantoConfig",
|
"QuantoConfig",
|
||||||
|
"TorchAoConfig",
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -5728,6 +5729,7 @@ if TYPE_CHECKING:
|
|||||||
GPTQConfig,
|
GPTQConfig,
|
||||||
HqqConfig,
|
HqqConfig,
|
||||||
QuantoConfig,
|
QuantoConfig,
|
||||||
|
TorchAoConfig,
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ from ..utils.quantization_config import (
|
|||||||
QuantizationConfigMixin,
|
QuantizationConfigMixin,
|
||||||
QuantizationMethod,
|
QuantizationMethod,
|
||||||
QuantoConfig,
|
QuantoConfig,
|
||||||
|
TorchAoConfig,
|
||||||
)
|
)
|
||||||
from .quantizer_aqlm import AqlmHfQuantizer
|
from .quantizer_aqlm import AqlmHfQuantizer
|
||||||
from .quantizer_awq import AwqQuantizer
|
from .quantizer_awq import AwqQuantizer
|
||||||
@@ -36,6 +37,7 @@ from .quantizer_fbgemm_fp8 import FbgemmFp8HfQuantizer
|
|||||||
from .quantizer_gptq import GptqHfQuantizer
|
from .quantizer_gptq import GptqHfQuantizer
|
||||||
from .quantizer_hqq import HqqHfQuantizer
|
from .quantizer_hqq import HqqHfQuantizer
|
||||||
from .quantizer_quanto import QuantoHfQuantizer
|
from .quantizer_quanto import QuantoHfQuantizer
|
||||||
|
from .quantizer_torchao import TorchAoHfQuantizer
|
||||||
|
|
||||||
|
|
||||||
AUTO_QUANTIZER_MAPPING = {
|
AUTO_QUANTIZER_MAPPING = {
|
||||||
@@ -48,6 +50,7 @@ AUTO_QUANTIZER_MAPPING = {
|
|||||||
"eetq": EetqHfQuantizer,
|
"eetq": EetqHfQuantizer,
|
||||||
"hqq": HqqHfQuantizer,
|
"hqq": HqqHfQuantizer,
|
||||||
"fbgemm_fp8": FbgemmFp8HfQuantizer,
|
"fbgemm_fp8": FbgemmFp8HfQuantizer,
|
||||||
|
"torchao": TorchAoHfQuantizer,
|
||||||
}
|
}
|
||||||
|
|
||||||
AUTO_QUANTIZATION_CONFIG_MAPPING = {
|
AUTO_QUANTIZATION_CONFIG_MAPPING = {
|
||||||
@@ -60,6 +63,7 @@ AUTO_QUANTIZATION_CONFIG_MAPPING = {
|
|||||||
"quanto": QuantoConfig,
|
"quanto": QuantoConfig,
|
||||||
"hqq": HqqConfig,
|
"hqq": HqqConfig,
|
||||||
"fbgemm_fp8": FbgemmFp8Config,
|
"fbgemm_fp8": FbgemmFp8Config,
|
||||||
|
"torchao": TorchAoConfig,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
172
src/transformers/quantizers/quantizer_torchao.py
Normal file
172
src/transformers/quantizers/quantizer_torchao.py
Normal file
@@ -0,0 +1,172 @@
|
|||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import importlib
|
||||||
|
from typing import TYPE_CHECKING, Union
|
||||||
|
|
||||||
|
from packaging import version
|
||||||
|
|
||||||
|
from .base import HfQuantizer
|
||||||
|
from .quantizers_utils import get_module_from_name
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ..modeling_utils import PreTrainedModel
|
||||||
|
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
from ..utils import is_torch_available, is_torchao_available, logging
|
||||||
|
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
import torch
|
||||||
|
|
||||||
|
if is_torchao_available():
|
||||||
|
from torchao.quantization import quantize_
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# Finds the parent of a node module named "name"
|
||||||
|
def find_parent(model, name):
|
||||||
|
module_tree = name.split(".")[:-1]
|
||||||
|
parent = model
|
||||||
|
for m in module_tree:
|
||||||
|
parent = parent._modules[m]
|
||||||
|
return parent
|
||||||
|
|
||||||
|
|
||||||
|
class TorchAoHfQuantizer(HfQuantizer):
|
||||||
|
"""
|
||||||
|
Quantizer for torchao: https://github.com/pytorch/ao/
|
||||||
|
"""
|
||||||
|
|
||||||
|
requires_parameters_quantization = True
|
||||||
|
requires_calibration = False
|
||||||
|
required_packages = ["torchao"]
|
||||||
|
|
||||||
|
def __init__(self, quantization_config, **kwargs):
|
||||||
|
super().__init__(quantization_config, **kwargs)
|
||||||
|
|
||||||
|
def validate_environment(self, *args, **kwargs):
|
||||||
|
if not is_torchao_available():
|
||||||
|
raise ImportError("Loading an torchao quantized model requires torchao library (`pip install torchao`)")
|
||||||
|
|
||||||
|
self.offload = False
|
||||||
|
device_map = kwargs.get("device_map", None)
|
||||||
|
if isinstance(device_map, dict):
|
||||||
|
if "cpu" in device_map.values() or "disk" in device_map.values():
|
||||||
|
if self.pre_quantized:
|
||||||
|
raise ValueError(
|
||||||
|
"You are attempting to perform cpu/disk offload with a pre-quantized torchao model "
|
||||||
|
"This is not supported yet . Please remove the CPU or disk device from the device_map."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.offload = True
|
||||||
|
|
||||||
|
def update_torch_dtype(self, torch_dtype):
|
||||||
|
if self.quantization_config.quant_type == "int4_weight_only":
|
||||||
|
if torch_dtype is not None and torch_dtype != torch.bfloat16:
|
||||||
|
logger.warning_once(
|
||||||
|
f"Setting torch_dtype to {torch_dtype} for int4_weight_only quantization, but only bfloat16 is supported right now. Please set the torch_dtype to bfloat16."
|
||||||
|
)
|
||||||
|
if torch_dtype is None:
|
||||||
|
logger.warning_once(
|
||||||
|
"Setting torch_dtype to torch.bfloat16 for int4_weight_only quantization since only bfloat16 is supported right now. Please set torch_dtype=torch.bfloat16 to remove this warning."
|
||||||
|
)
|
||||||
|
torch_dtype = torch.bfloat16
|
||||||
|
return torch_dtype
|
||||||
|
|
||||||
|
def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
|
||||||
|
if version.parse(importlib.metadata.version("accelerate")) > version.parse("0.19.0"):
|
||||||
|
from accelerate.utils import CustomDtype
|
||||||
|
|
||||||
|
map_to_target_dtype = {
|
||||||
|
"int4_weight_only": CustomDtype.INT4,
|
||||||
|
"int8_weight_only": torch.int8,
|
||||||
|
"int8_dynamic_activation_int8_weight": torch.int8,
|
||||||
|
}
|
||||||
|
return map_to_target_dtype[self.quantization_config.quant_type]
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"You are using `device_map='auto'` on a torchao quantized model. To automatically compute"
|
||||||
|
" the appropriate device map, you should upgrade your `accelerate` library with "
|
||||||
|
"`pip install --upgrade accelerate`"
|
||||||
|
)
|
||||||
|
|
||||||
|
def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]:
|
||||||
|
# need more space for the quantization parameters (e.g. scale). Tested with int4 wo and group size = 128
|
||||||
|
max_memory = {key: val * 0.9 for key, val in max_memory.items()}
|
||||||
|
return max_memory
|
||||||
|
|
||||||
|
def _process_model_before_weight_loading(self, model: "PreTrainedModel", **kwargs):
|
||||||
|
from ..integrations import get_keys_to_not_convert
|
||||||
|
|
||||||
|
self.modules_to_not_convert = get_keys_to_not_convert(model)
|
||||||
|
|
||||||
|
if self.quantization_config.modules_to_not_convert is not None:
|
||||||
|
self.modules_to_not_convert.extend(self.quantization_config.modules_to_not_convert)
|
||||||
|
|
||||||
|
return
|
||||||
|
|
||||||
|
def check_quantized_param(
|
||||||
|
self,
|
||||||
|
model: "PreTrainedModel",
|
||||||
|
param_value: "torch.Tensor",
|
||||||
|
param_name: str,
|
||||||
|
state_dict: Dict[str, Any],
|
||||||
|
**kwargs,
|
||||||
|
) -> bool:
|
||||||
|
param_device = kwargs.pop("param_device", None)
|
||||||
|
# check if the param_name is not in self.modules_to_not_convert
|
||||||
|
if any((key + "." in param_name) or (key == param_name) for key in self.modules_to_not_convert):
|
||||||
|
return False
|
||||||
|
elif param_device == "cpu" and self.offload:
|
||||||
|
# We don't quantize weights that we offload
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
# we only quantize the weight of nn.Linear
|
||||||
|
module, tensor_name = get_module_from_name(model, param_name)
|
||||||
|
return isinstance(module, torch.nn.Linear) and (tensor_name == "weight")
|
||||||
|
|
||||||
|
def create_quantized_param(
|
||||||
|
self,
|
||||||
|
model: "PreTrainedModel",
|
||||||
|
param_value: "torch.Tensor",
|
||||||
|
param_name: str,
|
||||||
|
target_device: "torch.device",
|
||||||
|
state_dict: Dict[str, Any],
|
||||||
|
unexpected_keys: List[str],
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Each nn.Linear layer that needs to be quantized is processsed here.
|
||||||
|
First, we set the value the weight tensor, then we move it to the target device. Finally, we quantize the module.
|
||||||
|
"""
|
||||||
|
module, tensor_name = get_module_from_name(model, param_name)
|
||||||
|
module._parameters[tensor_name] = torch.nn.Parameter(param_value).to(device=target_device)
|
||||||
|
quantize_(module, self.quantization_config.get_apply_tensor_subclass())
|
||||||
|
|
||||||
|
def _process_model_after_weight_loading(self, model):
|
||||||
|
"""No process required for torchao quantized model"""
|
||||||
|
return
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_serializable(self):
|
||||||
|
return False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_trainable(self):
|
||||||
|
# torchao does not have official support for QAT (Quantization Aware Training)
|
||||||
|
# but torchao support nf4/PEFT, but it is not integrated yet
|
||||||
|
# TODO: if this is supported in the future, do a version check here.
|
||||||
|
return False
|
||||||
@@ -127,6 +127,7 @@ from .utils import (
|
|||||||
is_torch_tf32_available,
|
is_torch_tf32_available,
|
||||||
is_torch_xla_available,
|
is_torch_xla_available,
|
||||||
is_torch_xpu_available,
|
is_torch_xpu_available,
|
||||||
|
is_torchao_available,
|
||||||
is_torchaudio_available,
|
is_torchaudio_available,
|
||||||
is_torchdynamo_available,
|
is_torchdynamo_available,
|
||||||
is_torchvision_available,
|
is_torchvision_available,
|
||||||
@@ -910,6 +911,11 @@ def require_torchdynamo(test_case):
|
|||||||
return unittest.skipUnless(is_torchdynamo_available(), "test requires TorchDynamo")(test_case)
|
return unittest.skipUnless(is_torchdynamo_available(), "test requires TorchDynamo")(test_case)
|
||||||
|
|
||||||
|
|
||||||
|
def require_torchao(test_case):
|
||||||
|
"""Decorator marking a test that requires torchao"""
|
||||||
|
return unittest.skipUnless(is_torchao_available(), "test requires torchao")(test_case)
|
||||||
|
|
||||||
|
|
||||||
def require_torch_tensorrt_fx(test_case):
|
def require_torch_tensorrt_fx(test_case):
|
||||||
"""Decorator marking a test that requires Torch-TensorRT FX"""
|
"""Decorator marking a test that requires Torch-TensorRT FX"""
|
||||||
return unittest.skipUnless(is_torch_tensorrt_fx_available(), "test requires Torch-TensorRT FX")(test_case)
|
return unittest.skipUnless(is_torch_tensorrt_fx_available(), "test requires Torch-TensorRT FX")(test_case)
|
||||||
|
|||||||
@@ -210,6 +210,7 @@ from .import_utils import (
|
|||||||
is_torch_tpu_available,
|
is_torch_tpu_available,
|
||||||
is_torch_xla_available,
|
is_torch_xla_available,
|
||||||
is_torch_xpu_available,
|
is_torch_xpu_available,
|
||||||
|
is_torchao_available,
|
||||||
is_torchaudio_available,
|
is_torchaudio_available,
|
||||||
is_torchdistx_available,
|
is_torchdistx_available,
|
||||||
is_torchdynamo_available,
|
is_torchdynamo_available,
|
||||||
|
|||||||
@@ -172,6 +172,7 @@ _tf2onnx_available = _is_package_available("tf2onnx")
|
|||||||
_timm_available = _is_package_available("timm")
|
_timm_available = _is_package_available("timm")
|
||||||
_tokenizers_available = _is_package_available("tokenizers")
|
_tokenizers_available = _is_package_available("tokenizers")
|
||||||
_torchaudio_available = _is_package_available("torchaudio")
|
_torchaudio_available = _is_package_available("torchaudio")
|
||||||
|
_torchao_available = _is_package_available("torchao")
|
||||||
_torchdistx_available = _is_package_available("torchdistx")
|
_torchdistx_available = _is_package_available("torchdistx")
|
||||||
_torchvision_available = _is_package_available("torchvision")
|
_torchvision_available = _is_package_available("torchvision")
|
||||||
_mlx_available = _is_package_available("mlx")
|
_mlx_available = _is_package_available("mlx")
|
||||||
@@ -1092,6 +1093,10 @@ def is_torchaudio_available():
|
|||||||
return _torchaudio_available
|
return _torchaudio_available
|
||||||
|
|
||||||
|
|
||||||
|
def is_torchao_available():
|
||||||
|
return _torchao_available
|
||||||
|
|
||||||
|
|
||||||
def is_speech_available():
|
def is_speech_available():
|
||||||
# For now this depends on torchaudio but the exact dependency might evolve in the future.
|
# For now this depends on torchaudio but the exact dependency might evolve in the future.
|
||||||
return _torchaudio_available
|
return _torchaudio_available
|
||||||
|
|||||||
@@ -20,17 +20,17 @@ import json
|
|||||||
import os
|
import os
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
from inspect import Parameter, signature
|
||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
from packaging import version
|
from packaging import version
|
||||||
|
|
||||||
from ..utils import is_auto_awq_available, is_hqq_available, is_torch_available, logging
|
from ..utils import is_auto_awq_available, is_hqq_available, is_torch_available, is_torchao_available, logging
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -43,6 +43,7 @@ class QuantizationMethod(str, Enum):
|
|||||||
EETQ = "eetq"
|
EETQ = "eetq"
|
||||||
HQQ = "hqq"
|
HQQ = "hqq"
|
||||||
FBGEMM_FP8 = "fbgemm_fp8"
|
FBGEMM_FP8 = "fbgemm_fp8"
|
||||||
|
TORCHAO = "torchao"
|
||||||
|
|
||||||
|
|
||||||
class AWQLinearVersion(str, Enum):
|
class AWQLinearVersion(str, Enum):
|
||||||
@@ -1079,3 +1080,84 @@ class FbgemmFp8Config(QuantizationConfigMixin):
|
|||||||
loading_attibutes = ["activation_scale_ub"]
|
loading_attibutes = ["activation_scale_ub"]
|
||||||
loading_attibutes_dict = {i: j for i, j in attibutes_dict.items() if i in loading_attibutes}
|
loading_attibutes_dict = {i: j for i, j in attibutes_dict.items() if i in loading_attibutes}
|
||||||
return loading_attibutes_dict
|
return loading_attibutes_dict
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TorchAoConfig(QuantizationConfigMixin):
|
||||||
|
"""This is a config class for torchao quantization/sparsity techniques.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
quant_type (`str`):
|
||||||
|
The type of quantization we want to use, currently supporting: `int4_weight_only`, `int8_weight_only` and `int8_dynamic_activation_int8_weight`.
|
||||||
|
modules_to_not_convert (`list`, *optional*, default to `None`):
|
||||||
|
The list of modules to not quantize, useful for quantizing models that explicitly require to have
|
||||||
|
some modules left in their original precision.
|
||||||
|
kwargs (`Dict[str, Any]`, *optional*):
|
||||||
|
The keyword arguments for the chosen type of quantization, for example, int4_weight_only quantization supports two keyword arguments
|
||||||
|
`group_size` and `inner_k_tiles` currently. More API examples and documentation of arguments can be found in
|
||||||
|
https://github.com/pytorch/ao/tree/main/torchao/quantization#other-available-quantization-techniques
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
quantization_config = TorchAoConfig("int4_weight_only", group_size=32)
|
||||||
|
# int4_weight_only quant is only working with *torch.bfloat16* dtype right now
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cuda", torch_dtype=torch.bfloat16, quantization_config=quantization_config)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, quant_type: str, modules_to_not_convert: Optional[List] = None, **kwargs):
|
||||||
|
self.quant_method = QuantizationMethod.TORCHAO
|
||||||
|
self.quant_type = quant_type
|
||||||
|
self.modules_to_not_convert = modules_to_not_convert
|
||||||
|
self.kwargs = kwargs
|
||||||
|
self._STR_TO_METHOD = {}
|
||||||
|
if is_torchao_available():
|
||||||
|
from torchao.quantization import (
|
||||||
|
int4_weight_only,
|
||||||
|
int8_dynamic_activation_int8_weight,
|
||||||
|
int8_weight_only,
|
||||||
|
)
|
||||||
|
|
||||||
|
self._STR_TO_METHOD = {
|
||||||
|
"int4_weight_only": int4_weight_only,
|
||||||
|
"int8_weight_only": int8_weight_only,
|
||||||
|
"int8_dynamic_activation_int8_weight": int8_dynamic_activation_int8_weight,
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"TorchAoConfig requires torchao to be installed, please install with `pip install torchao`"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.post_init()
|
||||||
|
|
||||||
|
def post_init(self):
|
||||||
|
r"""
|
||||||
|
Safety checker that arguments are correct - also replaces some NoneType arguments with their default values.
|
||||||
|
"""
|
||||||
|
if not version.parse(importlib.metadata.version("torchao")) >= version.parse("0.4.0"):
|
||||||
|
raise ValueError("Requires torchao 0.4.0 version and above")
|
||||||
|
|
||||||
|
if self.quant_type not in self._STR_TO_METHOD.keys():
|
||||||
|
raise ValueError(
|
||||||
|
f"Requested quantization type: {self.quant_type} is not supported yet, please add support in TorchAoConfig and TorchAoHfQuantizer."
|
||||||
|
)
|
||||||
|
|
||||||
|
method = self._STR_TO_METHOD[self.quant_type]
|
||||||
|
sig = signature(method)
|
||||||
|
all_kwargs = [
|
||||||
|
param.name
|
||||||
|
for param in sig.parameters.values()
|
||||||
|
if param.kind in [Parameter.KEYWORD_ONLY, Parameter.POSITIONAL_OR_KEYWORD]
|
||||||
|
]
|
||||||
|
for k in self.kwargs:
|
||||||
|
if k not in all_kwargs:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unexpected keyword arg: {k} for API: {method}, accepted keyword args are: {all_kwargs}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_apply_tensor_subclass(self):
|
||||||
|
return self._STR_TO_METHOD[self.quant_type](**self.kwargs)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"{self.quant_type}({', '.join(str(k) + '=' + str(v) for k, v in self.kwargs.items())})"
|
||||||
|
|||||||
0
tests/quantization/torchao_integration/__init__.py
Normal file
0
tests/quantization/torchao_integration/__init__.py
Normal file
213
tests/quantization/torchao_integration/test_torchao.py
Normal file
213
tests/quantization/torchao_integration/test_torchao.py
Normal file
@@ -0,0 +1,213 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import gc
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
|
||||||
|
from transformers.testing_utils import (
|
||||||
|
require_torch_gpu,
|
||||||
|
require_torch_multi_gpu,
|
||||||
|
require_torchao,
|
||||||
|
torch_device,
|
||||||
|
)
|
||||||
|
from transformers.utils import is_torch_available, is_torchao_available
|
||||||
|
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
import torch
|
||||||
|
|
||||||
|
if is_torchao_available():
|
||||||
|
from torchao.dtypes import AffineQuantizedTensor
|
||||||
|
from torchao.dtypes.affine_quantized_tensor import TensorCoreTiledLayoutType
|
||||||
|
|
||||||
|
|
||||||
|
def check_torchao_quantized(test_module, qlayer, batch_size=1, context_size=1024):
|
||||||
|
weight = qlayer.weight
|
||||||
|
test_module.assertTrue(isinstance(weight, AffineQuantizedTensor))
|
||||||
|
test_module.assertEqual(weight.quant_min, 0)
|
||||||
|
test_module.assertEqual(weight.quant_max, 15)
|
||||||
|
test_module.assertTrue(isinstance(weight.layout_type, TensorCoreTiledLayoutType))
|
||||||
|
|
||||||
|
|
||||||
|
def check_forward(test_module, model, batch_size=1, context_size=1024):
|
||||||
|
# Test forward pass
|
||||||
|
with torch.no_grad():
|
||||||
|
out = model(torch.zeros([batch_size, context_size], device=model.device, dtype=torch.int32)).logits
|
||||||
|
test_module.assertEqual(out.shape[0], batch_size)
|
||||||
|
test_module.assertEqual(out.shape[1], context_size)
|
||||||
|
|
||||||
|
|
||||||
|
@require_torch_gpu
|
||||||
|
@require_torchao
|
||||||
|
class TorchAoConfigTest(unittest.TestCase):
|
||||||
|
def test_to_dict(self):
|
||||||
|
"""
|
||||||
|
Makes sure the config format is properly set
|
||||||
|
"""
|
||||||
|
quantization_config = TorchAoConfig("int4_weight_only")
|
||||||
|
torchao_orig_config = quantization_config.to_dict()
|
||||||
|
|
||||||
|
for key in torchao_orig_config:
|
||||||
|
self.assertEqual(getattr(quantization_config, key), torchao_orig_config[key])
|
||||||
|
|
||||||
|
def test_post_init_check(self):
|
||||||
|
"""
|
||||||
|
Test kwargs validations in TorchAoConfig
|
||||||
|
"""
|
||||||
|
_ = TorchAoConfig("int4_weight_only")
|
||||||
|
with self.assertRaisesRegex(ValueError, "is not supported yet"):
|
||||||
|
_ = TorchAoConfig("fp6")
|
||||||
|
|
||||||
|
with self.assertRaisesRegex(ValueError, "Unexpected keyword arg"):
|
||||||
|
_ = TorchAoConfig("int4_weight_only", group_size1=32)
|
||||||
|
|
||||||
|
|
||||||
|
@require_torch_gpu
|
||||||
|
@require_torchao
|
||||||
|
class TorchAoTest(unittest.TestCase):
|
||||||
|
input_text = "What are we having for dinner?"
|
||||||
|
max_new_tokens = 10
|
||||||
|
|
||||||
|
EXPECTED_OUTPUT = "What are we having for dinner?\n- 1. What is the temperature outside"
|
||||||
|
|
||||||
|
model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
gc.collect()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
def test_int4wo_quant(self):
|
||||||
|
"""
|
||||||
|
Simple LLM model testing int4 weight only quantization
|
||||||
|
"""
|
||||||
|
quant_config = TorchAoConfig("int4_weight_only", group_size=32)
|
||||||
|
|
||||||
|
# Note: we quantize the bfloat16 model on the fly to int4
|
||||||
|
quantized_model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
self.model_name,
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device_map=torch_device,
|
||||||
|
quantization_config=quant_config,
|
||||||
|
)
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(self.model_name)
|
||||||
|
|
||||||
|
check_torchao_quantized(self, quantized_model.model.layers[0].self_attn.v_proj)
|
||||||
|
|
||||||
|
input_ids = tokenizer(self.input_text, return_tensors="pt").to(torch_device)
|
||||||
|
|
||||||
|
output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
|
||||||
|
self.assertEqual(tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
|
||||||
|
|
||||||
|
def test_int4wo_quant_bfloat16_conversion(self):
|
||||||
|
"""
|
||||||
|
Testing the dtype of model will be modified to be bfloat16 for int4 weight only quantization
|
||||||
|
"""
|
||||||
|
quant_config = TorchAoConfig("int4_weight_only", group_size=32)
|
||||||
|
|
||||||
|
# Note: we quantize the bfloat16 model on the fly to int4
|
||||||
|
quantized_model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
self.model_name,
|
||||||
|
torch_dtype=None,
|
||||||
|
device_map=torch_device,
|
||||||
|
quantization_config=quant_config,
|
||||||
|
)
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(self.model_name)
|
||||||
|
|
||||||
|
check_torchao_quantized(self, quantized_model.model.layers[0].self_attn.v_proj)
|
||||||
|
|
||||||
|
input_ids = tokenizer(self.input_text, return_tensors="pt").to(torch_device)
|
||||||
|
|
||||||
|
output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
|
||||||
|
self.assertEqual(tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
|
||||||
|
|
||||||
|
@require_torch_multi_gpu
|
||||||
|
def test_int4wo_quant_multi_gpu(self):
|
||||||
|
"""
|
||||||
|
Simple test that checks if the quantized model int4 wieght only is working properly with multiple GPUs
|
||||||
|
set CUDA_VISIBLE_DEVICES=0,1 if you have more than 2 GPUS
|
||||||
|
"""
|
||||||
|
|
||||||
|
quant_config = TorchAoConfig("int4_weight_only", group_size=32)
|
||||||
|
quantized_model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
self.model_name,
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device_map="auto",
|
||||||
|
quantization_config=quant_config,
|
||||||
|
)
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(self.model_name)
|
||||||
|
|
||||||
|
self.assertTrue(set(quantized_model.hf_device_map.values()) == {0, 1})
|
||||||
|
|
||||||
|
input_ids = tokenizer(self.input_text, return_tensors="pt").to(torch_device)
|
||||||
|
|
||||||
|
output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
|
||||||
|
self.assertEqual(tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
|
||||||
|
|
||||||
|
def test_int4wo_offload(self):
|
||||||
|
"""
|
||||||
|
Simple test that checks if the quantized model int4 wieght only is working properly with cpu/disk offload
|
||||||
|
"""
|
||||||
|
|
||||||
|
device_map_offload = {
|
||||||
|
"model.embed_tokens": 0,
|
||||||
|
"model.layers.0": 0,
|
||||||
|
"model.layers.1": 0,
|
||||||
|
"model.layers.2": 0,
|
||||||
|
"model.layers.3": 0,
|
||||||
|
"model.layers.4": 0,
|
||||||
|
"model.layers.5": 0,
|
||||||
|
"model.layers.6": 0,
|
||||||
|
"model.layers.7": 0,
|
||||||
|
"model.layers.8": 0,
|
||||||
|
"model.layers.9": 0,
|
||||||
|
"model.layers.10": 0,
|
||||||
|
"model.layers.11": 0,
|
||||||
|
"model.layers.12": 0,
|
||||||
|
"model.layers.13": 0,
|
||||||
|
"model.layers.14": 0,
|
||||||
|
"model.layers.15": 0,
|
||||||
|
"model.layers.16": 0,
|
||||||
|
"model.layers.17": 0,
|
||||||
|
"model.layers.18": 0,
|
||||||
|
"model.layers.19": "cpu",
|
||||||
|
"model.layers.20": "cpu",
|
||||||
|
"model.layers.21": "disk",
|
||||||
|
"model.norm": 0,
|
||||||
|
"model.rotary_emb": 0,
|
||||||
|
"lm_head": 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
quant_config = TorchAoConfig("int4_weight_only", group_size=32)
|
||||||
|
|
||||||
|
quantized_model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
self.model_name,
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device_map=device_map_offload,
|
||||||
|
quantization_config=quant_config,
|
||||||
|
)
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(self.model_name)
|
||||||
|
|
||||||
|
input_ids = tokenizer(self.input_text, return_tensors="pt").to(torch_device)
|
||||||
|
|
||||||
|
output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
|
||||||
|
EXPECTED_OUTPUT = "What are we having for dinner?\n- 2. What is the temperature outside"
|
||||||
|
|
||||||
|
self.assertEqual(tokenizer.decode(output[0], skip_special_tokens=True), EXPECTED_OUTPUT)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
Reference in New Issue
Block a user