diff --git a/docker/transformers-quantization-latest-gpu/Dockerfile b/docker/transformers-quantization-latest-gpu/Dockerfile index e1d084c403..08bc3c45b9 100644 --- a/docker/transformers-quantization-latest-gpu/Dockerfile +++ b/docker/transformers-quantization-latest-gpu/Dockerfile @@ -52,6 +52,9 @@ RUN python3 -m pip install --no-cache-dir https://github.com/casper-hansen/AutoA # Add quanto for quantization testing RUN python3 -m pip install --no-cache-dir quanto +# Add eetq for quantization testing +RUN python3 -m pip install git+https://github.com/NetEase-FuXi/EETQ.git + # When installing in editable mode, `transformers` is not recognized as a package. # this line must be added in order for python to be aware of transformers. RUN cd transformers && python3 setup.py develop \ No newline at end of file diff --git a/docs/source/en/main_classes/quantization.md b/docs/source/en/main_classes/quantization.md index d74e6861d2..91de5fc8a3 100644 --- a/docs/source/en/main_classes/quantization.md +++ b/docs/source/en/main_classes/quantization.md @@ -38,6 +38,9 @@ Learn how to quantize models in the [Quantization](../quantization) guide. [[autodoc]] AwqConfig +## EetqConfig +[[autodoc]] EetqConfig + ## GPTQConfig [[autodoc]] GPTQConfig diff --git a/docs/source/en/quantization.md b/docs/source/en/quantization.md index a6fa2f1f8c..8a3650a843 100644 --- a/docs/source/en/quantization.md +++ b/docs/source/en/quantization.md @@ -642,6 +642,37 @@ double_quant_config = BitsAndBytesConfig( model_double_quant = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-13b", quantization_config=double_quant_config) ``` +## EETQ +The [EETQ](https://github.com/NetEase-FuXi/EETQ) library supports int8 per-channel weight-only quantization for NVIDIA GPUs. The high-performance GEMM and GEMV kernels are from FasterTransformer and TensorRT-LLM. It requires no calibration dataset and does not need to pre-quantize your model. 
Moreover, the accuracy degradation is negligible owing to the per-channel quantization. + +Make sure you have eetq installed from the [release page](https://github.com/NetEase-FuXi/EETQ/releases): +``` +pip install --no-cache-dir https://github.com/NetEase-FuXi/EETQ/releases/download/v1.0.0/EETQ-1.0.0+cu121+torch2.1.2-cp310-cp310-linux_x86_64.whl +``` +or build it from the [source code](https://github.com/NetEase-FuXi/EETQ). EETQ requires a CUDA compute capability >= 7.0 and <= 8.9: +``` +git clone https://github.com/NetEase-FuXi/EETQ.git +cd EETQ/ +git submodule update --init --recursive +pip install . +``` + +An unquantized model can be quantized via `from_pretrained`. +```py +from transformers import AutoModelForCausalLM, EetqConfig +path = "/path/to/model" +quantization_config = EetqConfig("int8") +model = AutoModelForCausalLM.from_pretrained(path, device_map="auto", quantization_config=quantization_config) +``` + +A quantized model can be saved via `save_pretrained` and reloaded via `from_pretrained`. + +```py +quant_path = "/path/to/save/quantized/model" +model.save_pretrained(quant_path) +model = AutoModelForCausalLM.from_pretrained(quant_path, device_map="auto") +``` + ## Optimum The [Optimum](https://huggingface.co/docs/optimum/index) library supports quantization for Intel, Furiosa, ONNX Runtime, GPTQ, and lower-level PyTorch quantization functions. Consider using Optimum for quantization if you're using specific and optimized hardware like Intel CPUs, Furiosa NPUs or a model accelerator like ONNX Runtime. 
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index c07e3d8f1b..3ce3e057a2 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1126,7 +1126,14 @@ _import_structure = { "is_vision_available", "logging", ], - "utils.quantization_config": ["AqlmConfig", "AwqConfig", "BitsAndBytesConfig", "GPTQConfig", "QuantoConfig"], + "utils.quantization_config": [ + "AqlmConfig", + "AwqConfig", + "BitsAndBytesConfig", + "EetqConfig", + "GPTQConfig", + "QuantoConfig", + ], } # sentencepiece-backed objects @@ -6071,7 +6078,14 @@ if TYPE_CHECKING: ) # bitsandbytes config - from .utils.quantization_config import AqlmConfig, AwqConfig, BitsAndBytesConfig, GPTQConfig, QuantoConfig + from .utils.quantization_config import ( + AqlmConfig, + AwqConfig, + BitsAndBytesConfig, + EetqConfig, + GPTQConfig, + QuantoConfig, + ) try: if not is_sentencepiece_available(): diff --git a/src/transformers/integrations/__init__.py b/src/transformers/integrations/__init__.py index 0dc2975aa9..72fdf3e1bb 100644 --- a/src/transformers/integrations/__init__.py +++ b/src/transformers/integrations/__init__.py @@ -42,6 +42,7 @@ _import_structure = { "set_hf_deepspeed_config", "unset_hf_deepspeed_config", ], + "eetq": ["replace_with_eetq_linear"], "integration_utils": [ "INTEGRATION_TO_CALLBACK", "AzureMLCallback", @@ -111,6 +112,7 @@ if TYPE_CHECKING: set_hf_deepspeed_config, unset_hf_deepspeed_config, ) + from .eetq import replace_with_eetq_linear from .integration_utils import ( INTEGRATION_TO_CALLBACK, AzureMLCallback, diff --git a/src/transformers/integrations/eetq.py b/src/transformers/integrations/eetq.py new file mode 100644 index 0000000000..97698cf1aa --- /dev/null +++ b/src/transformers/integrations/eetq.py @@ -0,0 +1,121 @@ +# coding=utf-8 +# Copyright 2024 NetEase, Inc. and the HuggingFace Inc. team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from ..utils import is_accelerate_available, is_eetq_available, logging + + +if is_eetq_available(): + import eetq + import torch.nn as nn + +if is_accelerate_available(): + from accelerate import init_empty_weights + +logger = logging.get_logger(__name__) + + +def _replace_with_eetq_linear( + model, + modules_to_not_convert=None, + current_key_name=None, + quantization_config=None, + has_been_replaced=False, + pre_quantized=False, +): + """ + Private method that wraps the recursion for module replacement. + + Returns the converted model and a boolean that indicates if the conversion has been successfull or not. + """ + if current_key_name is None: + current_key_name = [] + + for name, module in model.named_children(): + current_key_name.append(name) + + if (isinstance(module, nn.Linear)) and name not in modules_to_not_convert: + # Check if the current key is not in the `modules_to_not_convert` + current_key_name_str = ".".join(current_key_name) + if not any( + (key + "." 
in current_key_name_str) or (key == current_key_name_str) for key in modules_to_not_convert + ): + with init_empty_weights(): + in_features = module.in_features + out_features = module.out_features + model._modules[name] = eetq.EetqLinear( + in_features, out_features, module.bias is not None, module.weight.device + ) + if pre_quantized: + model._modules[name].register_scale(module.weight.device) + has_been_replaced = True + + # Force requires grad to False to avoid unexpected errors + model._modules[name].requires_grad_(False) + if len(list(module.children())) > 0: + _, has_been_replaced = _replace_with_eetq_linear( + module, + modules_to_not_convert, + current_key_name, + quantization_config, + has_been_replaced=has_been_replaced, + pre_quantized=pre_quantized, + ) + # Remove the last key for recursion + current_key_name.pop(-1) + return model, has_been_replaced + + +def replace_with_eetq_linear( + model, modules_to_not_convert=None, current_key_name=None, quantization_config=None, pre_quantized=False +): + """ + A helper function to replace all `torch.nn.Linear` modules by `eetq.EetqLinear` modules from the `eetq` + library. This will enable running your models using high performance int8 weight-only gemm kerner from + FasterTransformer and TensorRT-LLM. Make sure `eetq` compiled with the correct CUDA + version of your hardware is installed before running this function. EETQ shall be installed via the source + 'https://github.com/NetEase-FuXi/EETQ' + + The function will be run recursively and replace all `torch.nn.Linear` modules except for the `lm_head` that should + be kept as a `torch.nn.Linear` module. The replacement is done under `init_empty_weights` context manager so no + CPU/GPU memory is required to run this function. Each weight will be quantized along the channel. + + Parameters: + model (`torch.nn.Module`): + Input model or `torch.nn.Module` as the function is run recursively. 
+ modules_to_not_convert (`List[`str`]`, *optional*, defaults to `["lm_head"]`): + Names of the modules to not convert in `EetqLinear`. In practice we keep the `lm_head` in full precision + for numerical stability reasons. + current_key_name (`List[`str`]`, *optional*): + An array to track the current key of the recursion. This is used to check whether the current key (part of + it) is not in the list of modules to not convert (for instances modules that are offloaded to `cpu` or + `disk`). + """ + + modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert + + if quantization_config.modules_to_not_convert is not None: + modules_to_not_convert.extend(quantization_config.modules_to_not_convert) + modules_to_not_convert = list(set(modules_to_not_convert)) + model, has_been_replaced = _replace_with_eetq_linear( + model, modules_to_not_convert, current_key_name, quantization_config, pre_quantized=pre_quantized + ) + + if not has_been_replaced: + logger.warning( + "You are loading your model using eetq but no linear modules were found in your model." + " Please double check your model architecture, or submit an issue on github if you think this is" + " a bug." 
+ ) + + return model diff --git a/src/transformers/quantizers/auto.py b/src/transformers/quantizers/auto.py index 616e206a45..cc58cd7af6 100644 --- a/src/transformers/quantizers/auto.py +++ b/src/transformers/quantizers/auto.py @@ -19,6 +19,7 @@ from ..utils.quantization_config import ( AqlmConfig, AwqConfig, BitsAndBytesConfig, + EetqConfig, GPTQConfig, QuantizationConfigMixin, QuantizationMethod, @@ -28,6 +29,7 @@ from .quantizer_aqlm import AqlmHfQuantizer from .quantizer_awq import AwqQuantizer from .quantizer_bnb_4bit import Bnb4BitHfQuantizer from .quantizer_bnb_8bit import Bnb8BitHfQuantizer +from .quantizer_eetq import EetqHfQuantizer from .quantizer_gptq import GptqHfQuantizer from .quantizer_quanto import QuantoHfQuantizer @@ -39,12 +41,14 @@ AUTO_QUANTIZER_MAPPING = { "gptq": GptqHfQuantizer, "aqlm": AqlmHfQuantizer, "quanto": QuantoHfQuantizer, + "eetq": EetqHfQuantizer, } AUTO_QUANTIZATION_CONFIG_MAPPING = { "awq": AwqConfig, "bitsandbytes_4bit": BitsAndBytesConfig, "bitsandbytes_8bit": BitsAndBytesConfig, + "eetq": EetqConfig, "gptq": GPTQConfig, "aqlm": AqlmConfig, "quanto": QuantoConfig, diff --git a/src/transformers/quantizers/quantizer_eetq.py b/src/transformers/quantizers/quantizer_eetq.py new file mode 100644 index 0000000000..547037a597 --- /dev/null +++ b/src/transformers/quantizers/quantizer_eetq.py @@ -0,0 +1,170 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import TYPE_CHECKING, Any, Dict, List, Optional + +from .base import HfQuantizer + + +if TYPE_CHECKING: + from ..modeling_utils import PreTrainedModel + +from ..utils import is_accelerate_available, is_eetq_available, is_torch_available, logging +from .quantizers_utils import get_module_from_name + + +if is_torch_available(): + import torch + + +logger = logging.get_logger(__name__) + + +class EetqHfQuantizer(HfQuantizer): + """ + 8-bit quantization from EETQ quantization method: + before loading: converts transformer layers into W8A16Linear during loading: load 16bit weight and pass to the + layer object after: quantizes individual weights in Linear8bitLt into 8bit at first .cuda() call + """ + + requires_parameters_quantization = True + requires_calibration = False + + required_packages = ["eetq", "accelerate"] + + def __init__(self, quantization_config, **kwargs): + super().__init__(quantization_config, **kwargs) + self.quantization_config = quantization_config + + def validate_environment(self, *args, **kwargs): + if not is_eetq_available(): + raise ImportError( + "Using `eetq` 8-bit quantization requires eetq." + "Please install the latest version of eetq from : https://github.com/NetEase-FuXi/EETQ" + ) + + if not is_accelerate_available(): + raise ImportError("Loading an EETQ quantized model requires accelerate (`pip install accelerate`)") + + if kwargs.get("from_tf", False) or kwargs.get("from_flax", False): + raise ValueError( + "Converting into 8-bit weights from tf/flax weights is currently not supported, please make" + " sure the weights are in PyTorch format." + ) + + if not torch.cuda.is_available(): + raise RuntimeError("No GPU found. A GPU is needed for quantization.") + + device_map = kwargs.get("device_map", None) + if device_map is None: + logger.warning_once( + "You have loaded an EETQ model on CPU and have a CUDA device available, make sure to set " + "your model on a GPU device in order to run your model." 
+ ) + elif device_map is not None: + if isinstance(device_map, dict) and ("cpu" in device_map.values() or "disk" in device_map.values()): + raise ValueError( + "You are attempting to load an EETQ model with a device_map that contains a CPU or disk device." + " This is not supported. Please remove the CPU or disk device from the device_map." + ) + + def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype": + if torch_dtype is None: + torch_dtype = torch.float16 + logger.info( + "Overriding torch_dtype=%s with `torch_dtype=torch.float16` due to " + "requirements of `eetq` to enable model loading in 8-bit. " + "Pass your own torch_dtype to specify the dtype of the remaining non-linear layers or pass" + " torch_dtype=torch.float16 to remove this warning.", + torch_dtype, + ) + elif torch_dtype != torch.float16: + logger.info("We suggest you to set `torch_dtype=torch.float16` for better efficiency with EETQ.") + return torch_dtype + + def check_quantized_param( + self, + model: "PreTrainedModel", + param_value: "torch.Tensor", + param_name: str, + state_dict: Dict[str, Any], + **kwargs, + ): + from eetq import EetqLinear + + module, tensor_name = get_module_from_name(model, param_name) + + if isinstance(module, EetqLinear): + if self.pre_quantized or tensor_name == "bias": + if tensor_name == "weight" and param_value.dtype != torch.int8: + raise ValueError("Expect quantized weights but got an unquantized weight") + return False + else: + if tensor_name == "weight_scale": + raise ValueError("Expect unquantized weights but got a quantized weight_scale") + return True + return False + + def create_quantized_param( + self, + model: "PreTrainedModel", + param_value: "torch.Tensor", + param_name: str, + target_device: "torch.device", + state_dict: Dict[str, Any], + unexpected_keys: Optional[List[str]] = None, + ): + """ + quantizes weights into qweight and weight_scales + """ + from eetq import quantize_and_preprocess_weights + + module, tensor_name = 
get_module_from_name(model, param_name) + new_value, weight_scale = quantize_and_preprocess_weights(param_value) + + module._buffers[tensor_name] = new_value.to(target_device) + module.register("weight_scales", weight_scale.to(target_device)) + + def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs): + return model + + def _process_model_before_weight_loading( + self, + model: "PreTrainedModel", + device_map, + keep_in_fp32_modules: List[str] = [], + **kwargs, + ): + from ..integrations import get_keys_to_not_convert, replace_with_eetq_linear + + self.modules_to_not_convert = get_keys_to_not_convert(model) + + if self.quantization_config.modules_to_not_convert is not None: + self.modules_to_not_convert.extend(self.quantization_config.modules_to_not_convert) + + model = replace_with_eetq_linear( + model, + modules_to_not_convert=self.modules_to_not_convert, + quantization_config=self.quantization_config, + pre_quantized=self.pre_quantized, + ) + + model.config.quantization_config = self.quantization_config + + @property + def is_serializable(self): + return True + + @property + def is_trainable(self) -> bool: + return False diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index 8297cb981e..be46d317df 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -65,6 +65,7 @@ from .utils import ( is_cython_available, is_decord_available, is_detectron2_available, + is_eetq_available, is_essentia_available, is_faiss_available, is_flash_attn_2_available, @@ -1014,6 +1015,13 @@ def require_aqlm(test_case): return unittest.skipUnless(is_aqlm_available(), "test requires aqlm")(test_case) +def require_eetq(test_case): + """ + Decorator marking a test that requires eetq + """ + return unittest.skipUnless(is_eetq_available(), "test requires eetq")(test_case) + + def require_av(test_case): """ Decorator marking a test that requires av diff --git a/src/transformers/utils/__init__.py 
b/src/transformers/utils/__init__.py index 121c4dc136..e4ff991ed7 100644 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -119,6 +119,7 @@ from .import_utils import ( is_datasets_available, is_decord_available, is_detectron2_available, + is_eetq_available, is_essentia_available, is_faiss_available, is_flash_attn_2_available, diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index a8c45aeac3..c65d4122b7 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -97,6 +97,7 @@ _apex_available = _is_package_available("apex") _aqlm_available = _is_package_available("aqlm") _av_available = importlib.util.find_spec("av") is not None _bitsandbytes_available = _is_package_available("bitsandbytes") +_eetq_available = _is_package_available("eetq") _galore_torch_available = _is_package_available("galore_torch") # `importlib.metadata.version` doesn't work with `bs4` but `beautifulsoup4`. For `importlib.util.find_spec`, reversed. 
_bs4_available = importlib.util.find_spec("bs4") is not None @@ -829,6 +830,10 @@ def is_auto_gptq_available(): return _auto_gptq_available +def is_eetq_available(): + return _eetq_available + + def is_levenshtein_available(): return _levenshtein_available diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py index d91ecef16e..8374ddef81 100644 --- a/src/transformers/utils/quantization_config.py +++ b/src/transformers/utils/quantization_config.py @@ -40,6 +40,7 @@ class QuantizationMethod(str, Enum): AWQ = "awq" AQLM = "aqlm" QUANTO = "quanto" + EETQ = "eetq" class AWQLinearVersion(str, Enum): @@ -893,3 +894,37 @@ class QuantoConfig(QuantizationConfigMixin): raise ValueError(f"Only support weights in {accepted_weights} but found {self.weights}") if self.activations not in accepted_activations: raise ValueError(f"Only support weights in {accepted_activations} but found {self.activations}") + + +@dataclass +class EetqConfig(QuantizationConfigMixin): + """ + This is a wrapper class about all possible attributes and features that you can play with a model that has been + loaded using `eetq`. + + Args: + weights (`str`, *optional*, defaults to `"int8"`): + The target dtype for the weights. Supported value is only "int8" + modules_to_not_convert (`list`, *optional*, default to `None`): + The list of modules to not quantize, useful for quantizing models that explicitly require to have + some modules left in their original precision. 
+ """ + + def __init__( + self, + weights: str = "int8", + modules_to_not_convert: Optional[List] = None, + **kwargs, + ): + self.quant_method = QuantizationMethod.EETQ + self.weights = weights + self.modules_to_not_convert = modules_to_not_convert + self.post_init() + + def post_init(self): + r""" + Safety checker that arguments are correct + """ + accepted_weights = ["int8"] + if self.weights not in accepted_weights: + raise ValueError(f"Only support weights in {accepted_weights} but found {self.weights}") diff --git a/tests/quantization/eetq_integration/__init__.py b/tests/quantization/eetq_integration/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/quantization/eetq_integration/test_eetq.py b/tests/quantization/eetq_integration/test_eetq.py new file mode 100644 index 0000000000..2c01f8145c --- /dev/null +++ b/tests/quantization/eetq_integration/test_eetq.py @@ -0,0 +1,171 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import gc +import tempfile +import unittest + +from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, EetqConfig, OPTForCausalLM +from transformers.testing_utils import ( + require_accelerate, + require_eetq, + require_torch_gpu, + require_torch_multi_gpu, + slow, + torch_device, +) +from transformers.utils import is_accelerate_available, is_torch_available + + +if is_torch_available(): + import torch + +if is_accelerate_available(): + from accelerate import init_empty_weights + + +@require_torch_gpu +class EetqConfigTest(unittest.TestCase): + def test_to_dict(self): + """ + Simple test that checks if one uses a config and converts it to a dict, the dict is the same as the config object + """ + quantization_config = EetqConfig() + config_to_dict = quantization_config.to_dict() + + for key in config_to_dict: + self.assertEqual(getattr(quantization_config, key), config_to_dict[key]) + + def test_from_dict(self): + """ + Simple test that checks if one uses a dict and converts it to a config object, the config object is the same as the dict + """ + dict = {"modules_to_not_convert": ["lm_head.weight"], "quant_method": "eetq", "weights": "int8"} + quantization_config = EetqConfig.from_dict(dict) + + self.assertEqual(dict["modules_to_not_convert"], quantization_config.modules_to_not_convert) + self.assertEqual(dict["quant_method"], quantization_config.quant_method) + self.assertEqual(dict["weights"], quantization_config.weights) + + +@slow +@require_torch_gpu +@require_eetq +@require_accelerate +class EetqTest(unittest.TestCase): + model_name = "facebook/opt-350m" + + input_text = "What are we having for dinner?" 
+ max_new_tokens = 9 + + EXPECTED_OUTPUT = "What are we having for dinner?\nI'm having a steak and a salad" + + device_map = "cuda" + + # called only once for all test in this class + @classmethod + def setUpClass(cls): + """ + Setup quantized model + """ + quantization_config = EetqConfig(weights="int8") + cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name) + cls.quantized_model = AutoModelForCausalLM.from_pretrained( + cls.model_name, device_map=cls.device_map, quantization_config=quantization_config + ) + + def tearDown(self): + gc.collect() + torch.cuda.empty_cache() + gc.collect() + + def test_quantized_model_conversion(self): + """ + Simple test that checks if the quantized model has been converted properly + """ + from eetq import EetqLinear + + from transformers.integrations import replace_with_eetq_linear + + model_id = "facebook/opt-350m" + config = AutoConfig.from_pretrained(model_id, revision="cb32f77e905cccbca1d970436fb0f5e6b58ee3c5") + quantization_config = EetqConfig(weights="int8") + + with init_empty_weights(): + model = OPTForCausalLM(config) + + nb_linears = 0 + for module in model.modules(): + if isinstance(module, torch.nn.Linear): + nb_linears += 1 + + model = replace_with_eetq_linear(model, quantization_config=quantization_config) + nb_eetq_linear = 0 + for module in model.modules(): + if isinstance(module, EetqLinear): + nb_eetq_linear += 1 + + self.assertEqual(nb_linears - 1, nb_eetq_linear) + + # Try with `linear_weights_not_to_quantize` + with init_empty_weights(): + model = OPTForCausalLM(config) + quantization_config = EetqConfig(modules_to_not_convert=["fc1"]) + model = replace_with_eetq_linear(model, quantization_config=quantization_config) + nb_eetq_linear = 0 + for module in model.modules(): + if isinstance(module, EetqLinear): + nb_eetq_linear += 1 + + self.assertEqual(nb_linears - 25, nb_eetq_linear) + + def test_quantized_model(self): + """ + Simple test that checks if the quantized model is working properly + """ + 
input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device) + + output = self.quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens) + self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT) + + def test_save_pretrained(self): + """ + Simple test that checks if the quantized model is working properly after being saved and loaded + """ + with tempfile.TemporaryDirectory() as tmpdirname: + self.quantized_model.save_pretrained(tmpdirname) + + model = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map=self.device_map) + + input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device) + + output = model.generate(**input_ids, max_new_tokens=self.max_new_tokens) + self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT) + + @require_torch_multi_gpu + def test_quantized_model_multi_gpu(self): + """ + Simple test that checks if the quantized model is working properly with multiple GPUs + set CUDA_VISIBLE_DEVICES=0,1 if you have more than 2 GPUS + """ + input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device) + quantization_config = EetqConfig() + quantized_model = AutoModelForCausalLM.from_pretrained( + self.model_name, device_map="auto", quantization_config=quantization_config + ) + self.assertTrue(set(quantized_model.hf_device_map.values()) == {0, 1}) + + output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens) + self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)