From 78d78cdf8ae0351554eaae4f528c532e3274cf50 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Wed, 14 Aug 2024 07:14:24 -0700 Subject: [PATCH] Add TorchAOHfQuantizer (#32306) * Add TorchAOHfQuantizer Summary: Enable loading torchao quantized model in huggingface. Test Plan: local test Reviewers: Subscribers: Tasks: Tags: * Fix a few issues * style * Added tests and addressed some comments about dtype conversion * fix torch_dtype warning message * fix tests * style * TorchAOConfig -> TorchAoConfig * enable offload + fix memory with multi-gpu * update torchao version requirement to 0.4.0 * better comments * add torch.compile to torchao README, add perf number link --------- Co-authored-by: Marc Sun --- docs/source/en/_toctree.yml | 2 + docs/source/en/main_classes/quantization.md | 4 + docs/source/en/quantization/overview.md | 2 +- docs/source/en/quantization/torchao.md | 45 ++++ src/transformers/__init__.py | 2 + src/transformers/quantizers/auto.py | 4 + .../quantizers/quantizer_torchao.py | 172 ++++++++++++++ src/transformers/testing_utils.py | 6 + src/transformers/utils/__init__.py | 1 + src/transformers/utils/import_utils.py | 5 + src/transformers/utils/quantization_config.py | 86 ++++++- .../torchao_integration/__init__.py | 0 .../torchao_integration/test_torchao.py | 213 ++++++++++++++++++ 13 files changed, 539 insertions(+), 3 deletions(-) create mode 100644 docs/source/en/quantization/torchao.md create mode 100644 src/transformers/quantizers/quantizer_torchao.py create mode 100644 tests/quantization/torchao_integration/__init__.py create mode 100644 tests/quantization/torchao_integration/test_torchao.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 2199a10948..9f89c8669d 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -163,6 +163,8 @@ title: FBGEMM_FP8 - local: quantization/optimum title: Optimum + - local: quantization/torchao + title: TorchAO - local: quantization/contribute title: Contribute new quantization method title: Quantization Methods diff --git a/docs/source/en/main_classes/quantization.md b/docs/source/en/main_classes/quantization.md index fc5808415c..91d15e6066 100755 --- a/docs/source/en/main_classes/quantization.md +++ b/docs/source/en/main_classes/quantization.md @@ -61,3 +61,7 @@ Learn how to quantize models in the [Quantization](../quantization) guide. [[autodoc]] FbgemmFp8Config +## TorchAoConfig + +[[autodoc]] TorchAoConfig + diff --git a/docs/source/en/quantization/overview.md b/docs/source/en/quantization/overview.md index 99fc669e49..9eb74793a1 100644 --- a/docs/source/en/quantization/overview.md +++ b/docs/source/en/quantization/overview.md @@ -56,4 +56,4 @@ Use the table below to help you decide which quantization method to use. | [HQQ](./hqq) | 🟢 | 🟢 | 🟢 | 🔴 | 🔴 | 🟢 | 1 - 8 | 🟢 | 🔴 | 🟢 | https://github.com/mobiusml/hqq/ | | [Quanto](./quanto) | 🟢 | 🟢 | 🟢 | 🔴 | 🟢 | 🟢 | 2 / 4 / 8 | 🔴 | 🔴 | 🟢 | https://github.com/huggingface/quanto | | [FBGEMM_FP8](./fbgemm_fp8.md) | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | 8 | 🔴 | 🟢 | 🟢 | https://github.com/pytorch/FBGEMM | - +| [torchao](./torchao.md) | 🟢 | | 🟢 | 🔴 | partial support (int4 weight only) | | 4 / 8 | | 🟢🔴 | 🟢 | https://github.com/pytorch/ao | diff --git a/docs/source/en/quantization/torchao.md b/docs/source/en/quantization/torchao.md new file mode 100644 index 0000000000..c9733497aa --- /dev/null +++ b/docs/source/en/quantization/torchao.md @@ -0,0 +1,45 @@ + + +# TorchAO + +[TorchAO](https://github.com/pytorch/ao) is an architecture optimization library for PyTorch, it provides high performance dtypes, optimization techniques and kernels for inference and training, featuring composability with native PyTorch features like `torch.compile`, FSDP etc.. Some benchmark numbers can be found [here](https://github.com/pytorch/ao/tree/main?tab=readme-ov-file#without-intrusive-code-changes) + +Before you begin, make sure the following libraries are installed with their latest version: + +```bash +pip install --upgrade torch torchao +``` + + +```py +from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer + +model_name = "meta-llama/Meta-Llama-3-8B" +# We support int4_weight_only, int8_weight_only and int8_dynamic_activation_int8_weight +# More examples and documentations for arguments can be found in https://github.com/pytorch/ao/tree/main/torchao/quantization#other-available-quantization-techniques +quantization_config = TorchAoConfig("int4_weight_only", group_size=128) +quantized_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", quantization_config=quantization_config) + +tokenizer = AutoTokenizer.from_pretrained(model_name) +input_text = "What are we having for dinner?" +input_ids = tokenizer(input_text, return_tensors="pt").to("cuda") + +# compile the quantizd model to get speedup +import torchao +torchao.quantization.utils.recommended_inductor_config_setter() +quantized_model = torch.compile(quantized_model, mode="max-autotune") + +output = quantized_model.generate(**input_ids, max_new_tokens=10) +print(tokenizer.decode(output[0], skip_special_tokens=True)) +``` + +torchao quantization is implemented with tensor subclasses, currently it does not work with huggingface serialization, both the safetensor option and [non-safetensor option](https://github.com/huggingface/transformers/issues/32364), we'll update here with instructions when it's working. diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index d6ef15d71b..b9bb8592ff 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -949,6 +949,7 @@ _import_structure = { "GPTQConfig", "HqqConfig", "QuantoConfig", + "TorchAoConfig", ], } @@ -5728,6 +5729,7 @@ if TYPE_CHECKING: GPTQConfig, HqqConfig, QuantoConfig, + TorchAoConfig, ) try: diff --git a/src/transformers/quantizers/auto.py b/src/transformers/quantizers/auto.py index 40aa86fc37..eaa0847f71 100755 --- a/src/transformers/quantizers/auto.py +++ b/src/transformers/quantizers/auto.py @@ -26,6 +26,7 @@ from ..utils.quantization_config import ( QuantizationConfigMixin, QuantizationMethod, QuantoConfig, + TorchAoConfig, ) from .quantizer_aqlm import AqlmHfQuantizer from .quantizer_awq import AwqQuantizer @@ -36,6 +37,7 @@ from .quantizer_fbgemm_fp8 import FbgemmFp8HfQuantizer from .quantizer_gptq import GptqHfQuantizer from .quantizer_hqq import HqqHfQuantizer from .quantizer_quanto import QuantoHfQuantizer +from .quantizer_torchao import TorchAoHfQuantizer AUTO_QUANTIZER_MAPPING = { @@ -48,6 +50,7 @@ AUTO_QUANTIZER_MAPPING = { "eetq": EetqHfQuantizer, "hqq": HqqHfQuantizer, "fbgemm_fp8": FbgemmFp8HfQuantizer, + "torchao": TorchAoHfQuantizer, } AUTO_QUANTIZATION_CONFIG_MAPPING = { @@ -60,6 +63,7 @@ AUTO_QUANTIZATION_CONFIG_MAPPING = { "quanto": QuantoConfig, "hqq": HqqConfig, "fbgemm_fp8": FbgemmFp8Config, + "torchao": TorchAoConfig, } diff --git a/src/transformers/quantizers/quantizer_torchao.py b/src/transformers/quantizers/quantizer_torchao.py new file mode 100644 index 0000000000..3b5dfff209 --- /dev/null +++ b/src/transformers/quantizers/quantizer_torchao.py @@ -0,0 +1,172 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import importlib +from typing import TYPE_CHECKING, Union + +from packaging import version + +from .base import HfQuantizer +from .quantizers_utils import get_module_from_name + + +if TYPE_CHECKING: + from ..modeling_utils import PreTrainedModel + +from typing import Any, Dict, List + +from ..utils import is_torch_available, is_torchao_available, logging + + +if is_torch_available(): + import torch + +if is_torchao_available(): + from torchao.quantization import quantize_ + +logger = logging.get_logger(__name__) + + +# Finds the parent of a node module named "name" +def find_parent(model, name): + module_tree = name.split(".")[:-1] + parent = model + for m in module_tree: + parent = parent._modules[m] + return parent + + +class TorchAoHfQuantizer(HfQuantizer): + """ + Quantizer for torchao: https://github.com/pytorch/ao/ + """ + + requires_parameters_quantization = True + requires_calibration = False + required_packages = ["torchao"] + + def __init__(self, quantization_config, **kwargs): + super().__init__(quantization_config, **kwargs) + + def validate_environment(self, *args, **kwargs): + if not is_torchao_available(): + raise ImportError("Loading an torchao quantized model requires torchao library (`pip install torchao`)") + + self.offload = False + device_map = kwargs.get("device_map", None) + if isinstance(device_map, dict): + if "cpu" in device_map.values() or "disk" in device_map.values(): + if self.pre_quantized: + raise ValueError( + "You are attempting to perform cpu/disk offload with a pre-quantized torchao model " + "This is not supported yet . Please remove the CPU or disk device from the device_map." + ) + else: + self.offload = True + + def update_torch_dtype(self, torch_dtype): + if self.quantization_config.quant_type == "int4_weight_only": + if torch_dtype is not None and torch_dtype != torch.bfloat16: + logger.warning_once( + f"Setting torch_dtype to {torch_dtype} for int4_weight_only quantization, but only bfloat16 is supported right now. Please set the torch_dtype to bfloat16." + ) + if torch_dtype is None: + logger.warning_once( + "Setting torch_dtype to torch.bfloat16 for int4_weight_only quantization since only bfloat16 is supported right now. Please set torch_dtype=torch.bfloat16 to remove this warning." + ) + torch_dtype = torch.bfloat16 + return torch_dtype + + def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype": + if version.parse(importlib.metadata.version("accelerate")) > version.parse("0.19.0"): + from accelerate.utils import CustomDtype + + map_to_target_dtype = { + "int4_weight_only": CustomDtype.INT4, + "int8_weight_only": torch.int8, + "int8_dynamic_activation_int8_weight": torch.int8, + } + return map_to_target_dtype[self.quantization_config.quant_type] + else: + raise ValueError( + "You are using `device_map='auto'` on a torchao quantized model. To automatically compute" + " the appropriate device map, you should upgrade your `accelerate` library with " + "`pip install --upgrade accelerate`" + ) + + def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]: + # need more space for the quantization parameters (e.g. scale). Tested with int4 wo and group size = 128 + max_memory = {key: val * 0.9 for key, val in max_memory.items()} + return max_memory + + def _process_model_before_weight_loading(self, model: "PreTrainedModel", **kwargs): + from ..integrations import get_keys_to_not_convert + + self.modules_to_not_convert = get_keys_to_not_convert(model) + + if self.quantization_config.modules_to_not_convert is not None: + self.modules_to_not_convert.extend(self.quantization_config.modules_to_not_convert) + + return + + def check_quantized_param( + self, + model: "PreTrainedModel", + param_value: "torch.Tensor", + param_name: str, + state_dict: Dict[str, Any], + **kwargs, + ) -> bool: + param_device = kwargs.pop("param_device", None) + # check if the param_name is not in self.modules_to_not_convert + if any((key + "." in param_name) or (key == param_name) for key in self.modules_to_not_convert): + return False + elif param_device == "cpu" and self.offload: + # We don't quantize weights that we offload + return False + else: + # we only quantize the weight of nn.Linear + module, tensor_name = get_module_from_name(model, param_name) + return isinstance(module, torch.nn.Linear) and (tensor_name == "weight") + + def create_quantized_param( + self, + model: "PreTrainedModel", + param_value: "torch.Tensor", + param_name: str, + target_device: "torch.device", + state_dict: Dict[str, Any], + unexpected_keys: List[str], + ): + """ + Each nn.Linear layer that needs to be quantized is processsed here. + First, we set the value the weight tensor, then we move it to the target device. Finally, we quantize the module. + """ + module, tensor_name = get_module_from_name(model, param_name) + module._parameters[tensor_name] = torch.nn.Parameter(param_value).to(device=target_device) + quantize_(module, self.quantization_config.get_apply_tensor_subclass()) + + def _process_model_after_weight_loading(self, model): + """No process required for torchao quantized model""" + return + + @property + def is_serializable(self): + return False + + @property + def is_trainable(self): + # torchao does not have official support for QAT (Quantization Aware Training) + # but torchao support nf4/PEFT, but it is not integrated yet + # TODO: if this is supported in the future, do a version check here. + return False diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index 6b04ed7426..8e234d7d1c 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -127,6 +127,7 @@ from .utils import ( is_torch_tf32_available, is_torch_xla_available, is_torch_xpu_available, + is_torchao_available, is_torchaudio_available, is_torchdynamo_available, is_torchvision_available, @@ -910,6 +911,11 @@ def require_torchdynamo(test_case): return unittest.skipUnless(is_torchdynamo_available(), "test requires TorchDynamo")(test_case) +def require_torchao(test_case): + """Decorator marking a test that requires torchao""" + return unittest.skipUnless(is_torchao_available(), "test requires torchao")(test_case) + + def require_torch_tensorrt_fx(test_case): """Decorator marking a test that requires Torch-TensorRT FX""" return unittest.skipUnless(is_torch_tensorrt_fx_available(), "test requires Torch-TensorRT FX")(test_case) diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py index 546a69132d..4df06118be 100755 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -210,6 +210,7 @@ from .import_utils import ( is_torch_tpu_available, is_torch_xla_available, is_torch_xpu_available, + is_torchao_available, is_torchaudio_available, is_torchdistx_available, is_torchdynamo_available, diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index 8eae679501..6840921ddc 100755 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -172,6 +172,7 @@ _tf2onnx_available = _is_package_available("tf2onnx") _timm_available = _is_package_available("timm") _tokenizers_available = _is_package_available("tokenizers") _torchaudio_available = _is_package_available("torchaudio") +_torchao_available = _is_package_available("torchao") _torchdistx_available = _is_package_available("torchdistx") _torchvision_available = _is_package_available("torchvision") _mlx_available = _is_package_available("mlx") @@ -1092,6 +1093,10 @@ def is_torchaudio_available(): return _torchaudio_available +def is_torchao_available(): + return _torchao_available + + def is_speech_available(): # For now this depends on torchaudio but the exact dependency might evolve in the future. return _torchaudio_available diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py index 5de8307c3b..f03407d3f7 100755 --- a/src/transformers/utils/quantization_config.py +++ b/src/transformers/utils/quantization_config.py @@ -20,17 +20,17 @@ import json import os from dataclasses import dataclass from enum import Enum +from inspect import Parameter, signature from typing import Any, Dict, List, Optional, Union from packaging import version -from ..utils import is_auto_awq_available, is_hqq_available, is_torch_available, logging +from ..utils import is_auto_awq_available, is_hqq_available, is_torch_available, is_torchao_available, logging if is_torch_available(): import torch - logger = logging.get_logger(__name__) @@ -43,6 +43,7 @@ class QuantizationMethod(str, Enum): EETQ = "eetq" HQQ = "hqq" FBGEMM_FP8 = "fbgemm_fp8" + TORCHAO = "torchao" class AWQLinearVersion(str, Enum): @@ -1079,3 +1080,84 @@ class FbgemmFp8Config(QuantizationConfigMixin): loading_attibutes = ["activation_scale_ub"] loading_attibutes_dict = {i: j for i, j in attibutes_dict.items() if i in loading_attibutes} return loading_attibutes_dict + + +@dataclass +class TorchAoConfig(QuantizationConfigMixin): + """This is a config class for torchao quantization/sparsity techniques. + + Args: + quant_type (`str`): + The type of quantization we want to use, currently supporting: `int4_weight_only`, `int8_weight_only` and `int8_dynamic_activation_int8_weight`. + modules_to_not_convert (`list`, *optional*, default to `None`): + The list of modules to not quantize, useful for quantizing models that explicitly require to have + some modules left in their original precision. + kwargs (`Dict[str, Any]`, *optional*): + The keyword arguments for the chosen type of quantization, for example, int4_weight_only quantization supports two keyword arguments + `group_size` and `inner_k_tiles` currently. More API examples and documentation of arguments can be found in + https://github.com/pytorch/ao/tree/main/torchao/quantization#other-available-quantization-techniques + + Example: + + ```python + quantization_config = TorchAoConfig("int4_weight_only", group_size=32) + # int4_weight_only quant is only working with *torch.bfloat16* dtype right now + model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cuda", torch_dtype=torch.bfloat16, quantization_config=quantization_config) + ``` + """ + + def __init__(self, quant_type: str, modules_to_not_convert: Optional[List] = None, **kwargs): + self.quant_method = QuantizationMethod.TORCHAO + self.quant_type = quant_type + self.modules_to_not_convert = modules_to_not_convert + self.kwargs = kwargs + self._STR_TO_METHOD = {} + if is_torchao_available(): + from torchao.quantization import ( + int4_weight_only, + int8_dynamic_activation_int8_weight, + int8_weight_only, + ) + + self._STR_TO_METHOD = { + "int4_weight_only": int4_weight_only, + "int8_weight_only": int8_weight_only, + "int8_dynamic_activation_int8_weight": int8_dynamic_activation_int8_weight, + } + else: + raise ValueError( + "TorchAoConfig requires torchao to be installed, please install with `pip install torchao`" + ) + + self.post_init() + + def post_init(self): + r""" + Safety checker that arguments are correct - also replaces some NoneType arguments with their default values. + """ + if not version.parse(importlib.metadata.version("torchao")) >= version.parse("0.4.0"): + raise ValueError("Requires torchao 0.4.0 version and above") + + if self.quant_type not in self._STR_TO_METHOD.keys(): + raise ValueError( + f"Requested quantization type: {self.quant_type} is not supported yet, please add support in TorchAoConfig and TorchAoHfQuantizer." + ) + + method = self._STR_TO_METHOD[self.quant_type] + sig = signature(method) + all_kwargs = [ + param.name + for param in sig.parameters.values() + if param.kind in [Parameter.KEYWORD_ONLY, Parameter.POSITIONAL_OR_KEYWORD] + ] + for k in self.kwargs: + if k not in all_kwargs: + raise ValueError( + f"Unexpected keyword arg: {k} for API: {method}, accepted keyword args are: {all_kwargs}" + ) + + def get_apply_tensor_subclass(self): + return self._STR_TO_METHOD[self.quant_type](**self.kwargs) + + def __repr__(self): + return f"{self.quant_type}({', '.join(str(k) + '=' + str(v) for k, v in self.kwargs.items())})" diff --git a/tests/quantization/torchao_integration/__init__.py b/tests/quantization/torchao_integration/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/quantization/torchao_integration/test_torchao.py b/tests/quantization/torchao_integration/test_torchao.py new file mode 100644 index 0000000000..8014f745d0 --- /dev/null +++ b/tests/quantization/torchao_integration/test_torchao.py @@ -0,0 +1,213 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import unittest + +from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig +from transformers.testing_utils import ( + require_torch_gpu, + require_torch_multi_gpu, + require_torchao, + torch_device, +) +from transformers.utils import is_torch_available, is_torchao_available + + +if is_torch_available(): + import torch + +if is_torchao_available(): + from torchao.dtypes import AffineQuantizedTensor + from torchao.dtypes.affine_quantized_tensor import TensorCoreTiledLayoutType + + +def check_torchao_quantized(test_module, qlayer, batch_size=1, context_size=1024): + weight = qlayer.weight + test_module.assertTrue(isinstance(weight, AffineQuantizedTensor)) + test_module.assertEqual(weight.quant_min, 0) + test_module.assertEqual(weight.quant_max, 15) + test_module.assertTrue(isinstance(weight.layout_type, TensorCoreTiledLayoutType)) + + +def check_forward(test_module, model, batch_size=1, context_size=1024): + # Test forward pass + with torch.no_grad(): + out = model(torch.zeros([batch_size, context_size], device=model.device, dtype=torch.int32)).logits + test_module.assertEqual(out.shape[0], batch_size) + test_module.assertEqual(out.shape[1], context_size) + + +@require_torch_gpu +@require_torchao +class TorchAoConfigTest(unittest.TestCase): + def test_to_dict(self): + """ + Makes sure the config format is properly set + """ + quantization_config = TorchAoConfig("int4_weight_only") + torchao_orig_config = quantization_config.to_dict() + + for key in torchao_orig_config: + self.assertEqual(getattr(quantization_config, key), torchao_orig_config[key]) + + def test_post_init_check(self): + """ + Test kwargs validations in TorchAoConfig + """ + _ = TorchAoConfig("int4_weight_only") + with self.assertRaisesRegex(ValueError, "is not supported yet"): + _ = TorchAoConfig("fp6") + + with self.assertRaisesRegex(ValueError, "Unexpected keyword arg"): + _ = TorchAoConfig("int4_weight_only", group_size1=32) + + +@require_torch_gpu +@require_torchao +class TorchAoTest(unittest.TestCase): + input_text = "What are we having for dinner?" + max_new_tokens = 10 + + EXPECTED_OUTPUT = "What are we having for dinner?\n- 1. What is the temperature outside" + + model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" + + def tearDown(self): + gc.collect() + torch.cuda.empty_cache() + gc.collect() + + def test_int4wo_quant(self): + """ + Simple LLM model testing int4 weight only quantization + """ + quant_config = TorchAoConfig("int4_weight_only", group_size=32) + + # Note: we quantize the bfloat16 model on the fly to int4 + quantized_model = AutoModelForCausalLM.from_pretrained( + self.model_name, + torch_dtype=torch.bfloat16, + device_map=torch_device, + quantization_config=quant_config, + ) + tokenizer = AutoTokenizer.from_pretrained(self.model_name) + + check_torchao_quantized(self, quantized_model.model.layers[0].self_attn.v_proj) + + input_ids = tokenizer(self.input_text, return_tensors="pt").to(torch_device) + + output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens) + self.assertEqual(tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT) + + def test_int4wo_quant_bfloat16_conversion(self): + """ + Testing the dtype of model will be modified to be bfloat16 for int4 weight only quantization + """ + quant_config = TorchAoConfig("int4_weight_only", group_size=32) + + # Note: we quantize the bfloat16 model on the fly to int4 + quantized_model = AutoModelForCausalLM.from_pretrained( + self.model_name, + torch_dtype=None, + device_map=torch_device, + quantization_config=quant_config, + ) + tokenizer = AutoTokenizer.from_pretrained(self.model_name) + + check_torchao_quantized(self, quantized_model.model.layers[0].self_attn.v_proj) + + input_ids = tokenizer(self.input_text, return_tensors="pt").to(torch_device) + + output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens) + self.assertEqual(tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT) + + @require_torch_multi_gpu + def test_int4wo_quant_multi_gpu(self): + """ + Simple test that checks if the quantized model int4 wieght only is working properly with multiple GPUs + set CUDA_VISIBLE_DEVICES=0,1 if you have more than 2 GPUS + """ + + quant_config = TorchAoConfig("int4_weight_only", group_size=32) + quantized_model = AutoModelForCausalLM.from_pretrained( + self.model_name, + torch_dtype=torch.bfloat16, + device_map="auto", + quantization_config=quant_config, + ) + tokenizer = AutoTokenizer.from_pretrained(self.model_name) + + self.assertTrue(set(quantized_model.hf_device_map.values()) == {0, 1}) + + input_ids = tokenizer(self.input_text, return_tensors="pt").to(torch_device) + + output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens) + self.assertEqual(tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT) + + def test_int4wo_offload(self): + """ + Simple test that checks if the quantized model int4 wieght only is working properly with cpu/disk offload + """ + + device_map_offload = { + "model.embed_tokens": 0, + "model.layers.0": 0, + "model.layers.1": 0, + "model.layers.2": 0, + "model.layers.3": 0, + "model.layers.4": 0, + "model.layers.5": 0, + "model.layers.6": 0, + "model.layers.7": 0, + "model.layers.8": 0, + "model.layers.9": 0, + "model.layers.10": 0, + "model.layers.11": 0, + "model.layers.12": 0, + "model.layers.13": 0, + "model.layers.14": 0, + "model.layers.15": 0, + "model.layers.16": 0, + "model.layers.17": 0, + "model.layers.18": 0, + "model.layers.19": "cpu", + "model.layers.20": "cpu", + "model.layers.21": "disk", + "model.norm": 0, + "model.rotary_emb": 0, + "lm_head": 0, + } + + quant_config = TorchAoConfig("int4_weight_only", group_size=32) + + quantized_model = AutoModelForCausalLM.from_pretrained( + self.model_name, + torch_dtype=torch.bfloat16, + device_map=device_map_offload, + quantization_config=quant_config, + ) + tokenizer = AutoTokenizer.from_pretrained(self.model_name) + + input_ids = tokenizer(self.input_text, return_tensors="pt").to(torch_device) + + output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens) + EXPECTED_OUTPUT = "What are we having for dinner?\n- 2. What is the temperature outside" + + self.assertEqual(tokenizer.decode(output[0], skip_special_tokens=True), EXPECTED_OUTPUT) + + +if __name__ == "__main__": + unittest.main()