Add new quant method (#32047)
* Add new quant method * update * fix multi-device * add test * add offload * style * style * add simple example * initial doc * docstring * style again * works ? * better docs * switch to non persistant * remove print * fix init * code review
This commit is contained in:
@@ -157,6 +157,8 @@
|
||||
title: EETQ
|
||||
- local: quantization/hqq
|
||||
title: HQQ
|
||||
- local: quantization/fbgemm_fp8
|
||||
title: FBGEMM_FP8
|
||||
- local: quantization/optimum
|
||||
title: Optimum
|
||||
- local: quantization/contribute
|
||||
|
||||
@@ -56,3 +56,8 @@ Learn how to quantize models in the [Quantization](../quantization) guide.
|
||||
## HqqConfig
|
||||
|
||||
[[autodoc]] HqqConfig
|
||||
|
||||
## FbgemmFp8Config
|
||||
|
||||
[[autodoc]] FbgemmFp8Config
|
||||
|
||||
|
||||
58
docs/source/en/quantization/fbgemm_fp8.md
Normal file
58
docs/source/en/quantization/fbgemm_fp8.md
Normal file
@@ -0,0 +1,58 @@
|
||||
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# FBGEMM FP8
|
||||
|
||||
With the FBGEMM FP8 quantization method, you can quantize your model in FP8 (W8A8):
|
||||
- the weights will be quantized in 8bit (FP8) per channel
|
||||
- the activation will be quantized in 8bit (FP8) per token
|
||||
|
||||
It relies on the [FBGEMM](https://github.com/pytorch/FBGEMM) library which provides efficient low-precision general matrix multiplication for small batch sizes and support for accuracy-loss minimizing techniques such as row-wise quantization and outlier-aware quantization.
|
||||
|
||||
> [!TIP]
|
||||
> You need a GPU with compute capability>=9 (e.g. H100)
|
||||
|
||||
Before you begin, make sure the following libraries are installed with their latest version:
|
||||
|
||||
```bash
|
||||
pip install --upgrade accelerate fbgemm-gpu torch
|
||||
```
|
||||
|
||||
If you are having issues with the fbgemm-gpu and torch libraries, you might need to install the nightly release. You can follow the instructions [here](https://pytorch.org/FBGEMM/fbgemm_gpu-development/InstallationInstructions.html#fbgemm-gpu-install-libraries:~:text=found%20here.-,Install%20the%20FBGEMM_GPU%20Package,-Install%20through%20PyTorch).
|
||||
|
||||
|
||||
```py
|
||||
from transformers import FbgemmFp8Config, AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
model_name = "meta-llama/Meta-Llama-3-8B"
|
||||
quantization_config = FbgemmFp8Config()
|
||||
quantized_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", quantization_config=quantization_config)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
input_text = "What are we having for dinner?"
|
||||
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
|
||||
|
||||
output = quantized_model.generate(**input_ids, max_new_tokens=10)
|
||||
print(tokenizer.decode(output[0], skip_special_tokens=True))
|
||||
```
|
||||
|
||||
A quantized model can be saved with `save_pretrained` and loaded again with `from_pretrained`.
|
||||
|
||||
```py
|
||||
quant_path = "/path/to/save/quantized/model"
|
||||
quantized_model.save_pretrained(quant_path)
|
||||
model = AutoModelForCausalLM.from_pretrained(quant_path, device_map="auto")
|
||||
```
|
||||
@@ -55,4 +55,5 @@ Use the table below to help you decide which quantization method to use.
|
||||
| [GPTQ](./gptq) | 🔴 | 🔴 | 🟢 | 🟢 | 🔴 | 🔴 | 2 - 3 - 4 - 8 | 🟢 | 🟢 | 🟢 | https://github.com/AutoGPTQ/AutoGPTQ |
|
||||
| [HQQ](./hqq) | 🟢 | 🟢 | 🟢 | 🔴 | 🔴 | 🟢 | 1 - 8 | 🟢 | 🔴 | 🟢 | https://github.com/mobiusml/hqq/ |
|
||||
| [Quanto](./quanto) | 🟢 | 🟢 | 🟢 | 🔴 | 🟢 | 🟢 | 2 / 4 / 8 | 🔴 | 🔴 | 🟢 | https://github.com/huggingface/quanto |
|
||||
| [FBGEMM_FP8](./fbgemm_fp8) | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | 8 | 🔴 | 🟢 | 🟢 | https://github.com/pytorch/FBGEMM |
|
||||
|
||||
|
||||
@@ -934,6 +934,7 @@ _import_structure = {
|
||||
"AwqConfig",
|
||||
"BitsAndBytesConfig",
|
||||
"EetqConfig",
|
||||
"FbgemmFp8Config",
|
||||
"GPTQConfig",
|
||||
"HqqConfig",
|
||||
"QuantoConfig",
|
||||
@@ -5665,6 +5666,7 @@ if TYPE_CHECKING:
|
||||
AwqConfig,
|
||||
BitsAndBytesConfig,
|
||||
EetqConfig,
|
||||
FbgemmFp8Config,
|
||||
GPTQConfig,
|
||||
HqqConfig,
|
||||
QuantoConfig,
|
||||
|
||||
@@ -45,6 +45,7 @@ _import_structure = {
|
||||
"unset_hf_deepspeed_config",
|
||||
],
|
||||
"eetq": ["replace_with_eetq_linear"],
|
||||
"fbgemm_fp8": ["FbgemmFp8Linear", "replace_with_fbgemm_fp8_linear"],
|
||||
"ggml": [
|
||||
"GGUF_CONFIG_MAPPING",
|
||||
"GGUF_TENSOR_MAPPING",
|
||||
@@ -126,6 +127,7 @@ if TYPE_CHECKING:
|
||||
unset_hf_deepspeed_config,
|
||||
)
|
||||
from .eetq import replace_with_eetq_linear
|
||||
from .fbgemm_fp8 import FbgemmFp8Linear, replace_with_fbgemm_fp8_linear
|
||||
from .ggml import (
|
||||
GGUF_CONFIG_MAPPING,
|
||||
GGUF_TENSOR_MAPPING,
|
||||
|
||||
161
src/transformers/integrations/fbgemm_fp8.py
Normal file
161
src/transformers/integrations/fbgemm_fp8.py
Normal file
@@ -0,0 +1,161 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ..utils import is_accelerate_available, is_fbgemm_gpu_available, is_torch_available, logging
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
if is_accelerate_available():
|
||||
from accelerate import init_empty_weights
|
||||
|
||||
if is_fbgemm_gpu_available():
|
||||
import fbgemm_gpu.experimental.gen_ai # noqa: F401
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class FbgemmFp8Linear(torch.nn.Module):
|
||||
def __init__(self, in_features, out_features, bias, weight_dtype=torch.float32):
|
||||
super().__init__()
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
|
||||
self.register_buffer("weight", torch.zeros((out_features, in_features), dtype=torch.float8_e4m3fn))
|
||||
self.register_buffer("weight_scale", torch.zeros((out_features, 1), dtype=weight_dtype))
|
||||
self.register_buffer("input_scale_ub", torch.zeros([1], dtype=torch.float), persistent=False)
|
||||
|
||||
if bias:
|
||||
self.register_buffer("bias", torch.zeros((self.out_features), dtype=weight_dtype))
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
def forward(self, x):
|
||||
num_tokens = None
|
||||
# x_quantized and x_scale are not necessarily on the same device as x, this is an issue.
|
||||
# https://github.com/pytorch/FBGEMM/blob/e08af8539c391437f447173863df0f3f6f6f1855/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cu#L1237C3-L1237C45
|
||||
x_quantized, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(
|
||||
x.view(-1, x.shape[-1]), num_tokens, self.input_scale_ub
|
||||
)
|
||||
# moving x_quantized, x_scale here creates glibberish output ... However, if we move the output, it works
|
||||
# x_quantized, x_scale = x_quantized.to(x.device), x_scale.to(x.device)
|
||||
|
||||
# The computation still happens on the device where self.weight is even if x_quantized is not on the same device as self.weight
|
||||
output = torch.ops.fbgemm.f8f8bf16_rowwise(
|
||||
x_quantized, self.weight, x_scale, self.weight_scale, use_fast_accum=True
|
||||
)
|
||||
output = output + self.bias if self.bias is not None else output
|
||||
# Hacky for now, we have the output to the device of x
|
||||
output = output.to(x.device)
|
||||
del x_quantized, x_scale
|
||||
return output
|
||||
|
||||
|
||||
def _replace_with_fbgemm_fp8_linear(
    model,
    modules_to_not_convert=None,
    current_key_name=None,
    quantization_config=None,
    has_been_replaced=False,
    pre_quantized=False,
):
    """
    Recursive worker behind `replace_with_fbgemm_fp8_linear`.

    Walks the direct children of `model`, swaps every eligible `nn.Linear` for an `FbgemmFp8Linear` and then
    descends into children that have sub-modules of their own. `current_key_name` holds the path of module
    names from the root to the child being visited. `pre_quantized` is only forwarded to the recursive calls.

    Returns a tuple made of the (in-place modified) model and a boolean telling whether at least one module
    has been replaced.
    """
    if current_key_name is None:
        current_key_name = []

    for child_name, child in model.named_children():
        current_key_name.append(child_name)

        if isinstance(child, nn.Linear) and child_name not in modules_to_not_convert:
            # The full dotted path is checked too, so that an excluded parent also protects its children
            dotted_path = ".".join(current_key_name)
            excluded = any(
                (key + "." in dotted_path) or (key == dotted_path) for key in modules_to_not_convert
            )
            if not excluded:
                with init_empty_weights(include_buffers=True):
                    fp8_linear = FbgemmFp8Linear(
                        child.in_features,
                        child.out_features,
                        child.bias is not None,
                    )
                    model._modules[child_name] = fp8_linear
                    has_been_replaced = True

                    # Force requires grad to False to avoid unexpected errors
                    fp8_linear.requires_grad_(False)
                # The non persistent buffer is set outside of init_empty_weights
                fp8_linear.input_scale_ub = torch.tensor(
                    [quantization_config.activation_scale_ub], dtype=torch.float
                )

        child_has_submodules = next(child.children(), None) is not None
        if child_has_submodules:
            _, has_been_replaced = _replace_with_fbgemm_fp8_linear(
                child,
                modules_to_not_convert,
                current_key_name,
                quantization_config,
                has_been_replaced=has_been_replaced,
                pre_quantized=pre_quantized,
            )

        # Done with this child: drop its name before moving on to the next sibling
        current_key_name.pop()

    return model, has_been_replaced
|
||||
|
||||
|
||||
def replace_with_fbgemm_fp8_linear(
    model, modules_to_not_convert=None, current_key_name=None, quantization_config=None, pre_quantized=False
):
    """
    A helper function to replace all `torch.nn.Linear` modules by `FbgemmFp8Linear` modules.
    This will enable running your models using high performance fp8 kernel from FBGEMM library.

    The function will be run recursively and replace all `torch.nn.Linear` modules except for the `lm_head` that should
    be kept as a `torch.nn.Linear` module. The replacement is done under `init_empty_weights` context manager so no
    CPU/GPU memory is required to run this function. Each weight will be quantized along the channel.

    Parameters:
        model (`torch.nn.Module`):
            Input model or `torch.nn.Module` as the function is run recursively.
        modules_to_not_convert (`List[`str`]`, *optional*, defaults to `["lm_head"]`):
            Names of the modules to not convert in `FP8Linear`. In practice we keep the `lm_head` in full precision
            for numerical stability reasons. The list passed by the caller is left untouched.
        current_key_name (`List[`str`]`, *optional*):
            An array to track the current key of the recursion. This is used to check whether the current key (part of
            it) is not in the list of modules to not convert (for instances modules that are offloaded to `cpu` or
            `disk`).
        quantization_config (`FbgemmFp8Config`):
            The quantization config. Its `modules_to_not_convert` entries, if any, are added to the exclusion
            list, and its `activation_scale_ub` is used to initialize the replaced modules.
        pre_quantized (`bool`, *optional*, defaults to `False`):
            Forwarded as is to the recursive replacement helper.

    Returns:
        The model with its `torch.nn.Linear` modules replaced in place. A warning is logged when nothing has been
        replaced.
    """

    # Work on a copy: extending the caller's list in place would leak the config entries back into it.
    modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else list(modules_to_not_convert)

    if quantization_config.modules_to_not_convert is not None:
        modules_to_not_convert.extend(quantization_config.modules_to_not_convert)
    # Remove duplicates
    modules_to_not_convert = list(set(modules_to_not_convert))
    model, has_been_replaced = _replace_with_fbgemm_fp8_linear(
        model, modules_to_not_convert, current_key_name, quantization_config, pre_quantized=pre_quantized
    )

    if not has_been_replaced:
        logger.warning(
            "You are loading your model using FP8 quantization but no linear modules were found in your model."
            " Please double check your model architecture, or submit an issue on github if you think this is"
            " a bug."
        )

    return model
|
||||
@@ -868,7 +868,7 @@ def _load_state_dict_into_meta_model(
|
||||
|
||||
# We convert floating dtypes to the `dtype` passed. We want to keep the buffers/params
|
||||
# in int/uint/bool and not cast them.
|
||||
if dtype is not None and torch.is_floating_point(param):
|
||||
if dtype is not None and torch.is_floating_point(param) and param.dtype != torch.float8_e4m3fn:
|
||||
if (
|
||||
keep_in_fp32_modules is not None
|
||||
and any(
|
||||
@@ -894,7 +894,6 @@ def _load_state_dict_into_meta_model(
|
||||
old_param = getattr(old_param, split)
|
||||
if old_param is None:
|
||||
break
|
||||
|
||||
if old_param is not None:
|
||||
if dtype is None:
|
||||
param = param.to(old_param.dtype)
|
||||
@@ -3955,6 +3954,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
and hf_quantizer.quantization_config.quant_method == QuantizationMethod.HQQ
|
||||
):
|
||||
device_map_kwargs["force_hooks"] = True
|
||||
if (
|
||||
hf_quantizer is not None
|
||||
and hf_quantizer.quantization_config.quant_method == QuantizationMethod.FBGEMM_FP8
|
||||
and isinstance(device_map, dict)
|
||||
and ("cpu" in device_map.values() or "disk" in device_map.values())
|
||||
):
|
||||
device_map_kwargs["offload_buffers"] = True
|
||||
|
||||
if not is_fsdp_enabled() and not is_deepspeed_zero3_enabled():
|
||||
dispatch_model(model, **device_map_kwargs)
|
||||
|
||||
@@ -4105,7 +4112,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
if cls._keys_to_ignore_on_load_unexpected is not None:
|
||||
for pat in cls._keys_to_ignore_on_load_unexpected:
|
||||
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
|
||||
|
||||
if hf_quantizer is not None:
|
||||
missing_keys = hf_quantizer.update_missing_keys(model, missing_keys, prefix)
|
||||
|
||||
|
||||
@@ -20,6 +20,7 @@ from ..utils.quantization_config import (
|
||||
AwqConfig,
|
||||
BitsAndBytesConfig,
|
||||
EetqConfig,
|
||||
FbgemmFp8Config,
|
||||
GPTQConfig,
|
||||
HqqConfig,
|
||||
QuantizationConfigMixin,
|
||||
@@ -31,6 +32,7 @@ from .quantizer_awq import AwqQuantizer
|
||||
from .quantizer_bnb_4bit import Bnb4BitHfQuantizer
|
||||
from .quantizer_bnb_8bit import Bnb8BitHfQuantizer
|
||||
from .quantizer_eetq import EetqHfQuantizer
|
||||
from .quantizer_fbgemm_fp8 import FbgemmFp8HfQuantizer
|
||||
from .quantizer_gptq import GptqHfQuantizer
|
||||
from .quantizer_hqq import HqqHfQuantizer
|
||||
from .quantizer_quanto import QuantoHfQuantizer
|
||||
@@ -45,6 +47,7 @@ AUTO_QUANTIZER_MAPPING = {
|
||||
"quanto": QuantoHfQuantizer,
|
||||
"eetq": EetqHfQuantizer,
|
||||
"hqq": HqqHfQuantizer,
|
||||
"fbgemm_fp8": FbgemmFp8HfQuantizer,
|
||||
}
|
||||
|
||||
AUTO_QUANTIZATION_CONFIG_MAPPING = {
|
||||
@@ -56,6 +59,7 @@ AUTO_QUANTIZATION_CONFIG_MAPPING = {
|
||||
"aqlm": AqlmConfig,
|
||||
"quanto": QuantoConfig,
|
||||
"hqq": HqqConfig,
|
||||
"fbgemm_fp8": FbgemmFp8Config,
|
||||
}
|
||||
|
||||
|
||||
@@ -156,8 +160,11 @@ class AutoHfQuantizer:
|
||||
if isinstance(quantization_config, dict):
|
||||
quantization_config = AutoQuantizationConfig.from_dict(quantization_config)
|
||||
|
||||
if isinstance(quantization_config, (GPTQConfig, AwqConfig)) and quantization_config_from_args is not None:
|
||||
# special case for GPTQ / AWQ config collision
|
||||
if (
|
||||
isinstance(quantization_config, (GPTQConfig, AwqConfig, FbgemmFp8Config))
|
||||
and quantization_config_from_args is not None
|
||||
):
|
||||
# special case for GPTQ / AWQ / FbgemmFp8 config collision
|
||||
loading_attr_dict = quantization_config_from_args.get_loading_attributes()
|
||||
for attr, val in loading_attr_dict.items():
|
||||
setattr(quantization_config, attr, val)
|
||||
|
||||
205
src/transformers/quantizers/quantizer_fbgemm_fp8.py
Normal file
205
src/transformers/quantizers/quantizer_fbgemm_fp8.py
Normal file
@@ -0,0 +1,205 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import importlib
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
|
||||
from packaging import version
|
||||
|
||||
from .base import HfQuantizer
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..modeling_utils import PreTrainedModel
|
||||
|
||||
from ..utils import is_accelerate_available, is_fbgemm_gpu_available, is_torch_available, logging
|
||||
from .quantizers_utils import get_module_from_name
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class FbgemmFp8HfQuantizer(HfQuantizer):
    """
    FP8 quantization using fbgemm kernels
    """

    requires_parameters_quantization = True
    requires_calibration = False

    required_packages = ["fbgemm-gpu", "accelerate"]

    def __init__(self, quantization_config, **kwargs):
        super().__init__(quantization_config, **kwargs)
        self.quantization_config = quantization_config

    def validate_environment(self, *args, **kwargs):
        """
        Check that the libraries and the hardware needed by the fbgemm FP8 kernels are available.

        Raises:
            `ImportError`: if torch (>= 2.1.0), fbgemm-gpu or accelerate (>= 0.32.2) is missing or too old.
            `RuntimeError`: if no CUDA device is available.
            `ValueError`: if the GPU has a compute capability lower than 9, or if the model is quantized on the
                fly with a `device_map` that contains `"cpu"` or `"disk"`.
        """
        # `import importlib` alone does not guarantee that the `metadata` submodule is loaded
        import importlib.metadata

        if not is_torch_available() or version.parse(importlib.metadata.version("torch")) < version.parse("2.1.0"):
            raise ImportError(
                "Using fbgemm fp8 quantization requires torch >= 2.1.0. "
                "Please install the latest version of torch ( pip install --upgrade torch )"
            )
        if not is_fbgemm_gpu_available():
            raise ImportError(
                "Using fbgemm fp8 quantization requires fbgemm-gpu library. "
                "Please install the latest version of fbgemm-gpu library by following : https://pytorch.org/FBGEMM/fbgemm_gpu-development/InstallationInstructions.html#fbgemm-gpu-install-libraries"
            )

        if not is_accelerate_available("0.32.2"):
            raise ImportError(
                "Loading an FP8 quantized model requires accelerate > 0.32.1 (`pip install --upgrade accelerate`)"
            )

        if not torch.cuda.is_available():
            raise RuntimeError("Using FP8 quantized models with fbgemm kernels requires a GPU")

        major, _ = torch.cuda.get_device_capability()
        if major < 9:
            raise ValueError(
                "FP8 quantized models is only supported on GPUs with compute capability >= 9.0 (e.g H100)"
            )

        device_map = kwargs.get("device_map", None)
        if device_map is None:
            logger.warning_once(
                "You have loaded an FP8 model on CPU and have a CUDA device available, make sure to set "
                "your model on a GPU device in order to run your model. To remove this warning, pass device_map = 'cuda'. "
            )
        elif (
            not self.pre_quantized
            and isinstance(device_map, dict)
            and ("cpu" in device_map.values() or "disk" in device_map.values())
        ):
            # Offloading is only possible when the checkpoint already contains quantized weights
            raise ValueError(
                "You are attempting to load an FP8 model with a device_map that contains a CPU or disk device. "
                "This is not supported when the model is quantized on the fly. "
                "Please use a quantized checkpoint or remove the CPU or disk device from the device_map."
            )

    def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
        """
        Default `torch_dtype` to `torch.bfloat16` when it is not set, and reject `torch.float16`.

        Raises:
            `ValueError`: if `torch_dtype` is `torch.float16`.
        """
        if torch_dtype is None:
            # Log before overriding so that the message reports the value that was actually passed
            logger.info(
                "Overriding torch_dtype=%s with `torch_dtype=torch.bfloat16` due to "
                "requirements of `fbgemm-gpu` to enable model loading in fp8. "
                "Pass your own torch_dtype to specify the dtype of the remaining non-linear layers or pass"
                " torch_dtype=torch.bfloat16 to remove this warning.",
                torch_dtype,
            )
            torch_dtype = torch.bfloat16
        elif torch_dtype == torch.float16:
            raise ValueError(
                "You cannot use FP8 with torch_dtype=torch.float16. "
                "We recommend you passing torch_dtype=torch.bfloat16"
            )
        return torch_dtype

    def check_quantized_param(
        self,
        model: "PreTrainedModel",
        param_value: "torch.Tensor",
        param_name: str,
        state_dict: Dict[str, Any],
        **kwargs,
    ):
        """
        Tell whether `param_value` still has to be quantized by `create_quantized_param`.

        Returns `True` only for tensors other than the bias that belong to a `FbgemmFp8Linear` module when the
        checkpoint is not pre-quantized.

        Raises:
            `ValueError`: if the checkpoint is declared pre-quantized but a weight is not in
                `torch.float8_e4m3fn`, or if it is declared unquantized but contains a `weight_scale`.
        """
        from ..integrations import FbgemmFp8Linear

        module, tensor_name = get_module_from_name(model, param_name)

        if isinstance(module, FbgemmFp8Linear):
            if self.pre_quantized or tensor_name == "bias":
                if tensor_name == "weight" and param_value.dtype != torch.float8_e4m3fn:
                    raise ValueError("Expect quantized weights but got an unquantized weight")
                return False
            else:
                if tensor_name == "weight_scale":
                    raise ValueError("Expect unquantized weights but got a quantized weight_scale")
                return True
        return False

    def create_quantized_param(
        self,
        model: "PreTrainedModel",
        param_value: "torch.Tensor",
        param_name: str,
        target_device: "torch.device",
        state_dict: Dict[str, Any],
        unexpected_keys: Optional[List[str]] = None,
    ):
        """
        Quantizes weights into weight and weight_scale
        """
        new_value, weight_scale = torch.ops.fbgemm.quantize_fp8_per_row(param_value)

        module, tensor_name = get_module_from_name(model, param_name)
        module._buffers[tensor_name] = new_value.to(target_device)
        # to have the right output shape -> (out_features, 1)
        module._buffers["weight_scale"] = weight_scale.view(weight_scale.shape[0], 1).to(target_device)

        if unexpected_keys is not None and param_name in unexpected_keys:
            unexpected_keys.remove(param_name)

    def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
        return model

    def _process_model_before_weight_loading(
        self,
        model: "PreTrainedModel",
        device_map,
        keep_in_fp32_modules: Optional[List[str]] = None,
        **kwargs,
    ):
        """
        Replace the linear layers of `model` by `FbgemmFp8Linear` modules and attach the quantization config to
        `model.config`. `device_map` and `keep_in_fp32_modules` are accepted but not used here.
        """
        from ..integrations import get_keys_to_not_convert, replace_with_fbgemm_fp8_linear

        self.modules_to_not_convert = get_keys_to_not_convert(model)

        if self.quantization_config.modules_to_not_convert is not None:
            self.modules_to_not_convert.extend(self.quantization_config.modules_to_not_convert)

        model = replace_with_fbgemm_fp8_linear(
            model,
            modules_to_not_convert=self.modules_to_not_convert,
            quantization_config=self.quantization_config,
            pre_quantized=self.pre_quantized,
        )

        model.config.quantization_config = self.quantization_config

    def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> List[str]:
        """
        Drop from `missing_keys` the entries that belong to a `FbgemmFp8Linear` module and are neither a
        `.weight` nor a `.bias`.
        """
        from ..integrations import FbgemmFp8Linear

        not_missing_keys = []
        for name, module in model.named_modules():
            if isinstance(module, FbgemmFp8Linear):
                for missing in missing_keys:
                    if (
                        (name in missing or name in f"{prefix}.{missing}")
                        and not missing.endswith(".weight")
                        and not missing.endswith(".bias")
                    ):
                        not_missing_keys.append(missing)
        return [k for k in missing_keys if k not in not_missing_keys]

    @property
    def is_serializable(self):
        return True

    @property
    def is_trainable(self) -> bool:
        return False
|
||||
@@ -68,6 +68,7 @@ from .utils import (
|
||||
is_eetq_available,
|
||||
is_essentia_available,
|
||||
is_faiss_available,
|
||||
is_fbgemm_gpu_available,
|
||||
is_flash_attn_2_available,
|
||||
is_flax_available,
|
||||
is_fsdp_available,
|
||||
@@ -1116,6 +1117,13 @@ def require_quanto(test_case):
|
||||
return unittest.skipUnless(is_quanto_available(), "test requires quanto")(test_case)
|
||||
|
||||
|
||||
def require_fbgemm_gpu(test_case):
    """
    Decorator marking a test that requires fbgemm-gpu. The test is skipped when `is_fbgemm_gpu_available()`
    returns a falsy value.
    """
    skip_unless_fbgemm_gpu = unittest.skipUnless(is_fbgemm_gpu_available(), "test requires fbgemm-gpu")
    return skip_unless_fbgemm_gpu(test_case)
|
||||
|
||||
|
||||
def require_phonemizer(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires phonemizer
|
||||
|
||||
@@ -127,6 +127,7 @@ from .import_utils import (
|
||||
is_eetq_available,
|
||||
is_essentia_available,
|
||||
is_faiss_available,
|
||||
is_fbgemm_gpu_available,
|
||||
is_flash_attn_2_available,
|
||||
is_flash_attn_greater_or_equal,
|
||||
is_flash_attn_greater_or_equal_2_10,
|
||||
|
||||
@@ -98,6 +98,7 @@ _aqlm_available = _is_package_available("aqlm")
|
||||
_av_available = importlib.util.find_spec("av") is not None
|
||||
_bitsandbytes_available = _is_package_available("bitsandbytes")
|
||||
_eetq_available = _is_package_available("eetq")
|
||||
_fbgemm_gpu_available = _is_package_available("fbgemm_gpu")
|
||||
_galore_torch_available = _is_package_available("galore_torch")
|
||||
_lomo_available = _is_package_available("lomo_optim")
|
||||
# `importlib.metadata.version` doesn't work with `bs4` but `beautifulsoup4`. For `importlib.util.find_spec`, reversed.
|
||||
@@ -888,6 +889,10 @@ def is_eetq_available():
|
||||
return _eetq_available
|
||||
|
||||
|
||||
def is_fbgemm_gpu_available():
    """Return the module-level `_fbgemm_gpu_available` flag."""
    return _fbgemm_gpu_available
|
||||
|
||||
|
||||
def is_levenshtein_available():
|
||||
return _levenshtein_available
|
||||
|
||||
|
||||
@@ -42,6 +42,7 @@ class QuantizationMethod(str, Enum):
|
||||
QUANTO = "quanto"
|
||||
EETQ = "eetq"
|
||||
HQQ = "hqq"
|
||||
FBGEMM_FP8 = "fbgemm_fp8"
|
||||
|
||||
|
||||
class AWQLinearVersion(str, Enum):
|
||||
@@ -1047,3 +1048,34 @@ class EetqConfig(QuantizationConfigMixin):
|
||||
accepted_weights = ["int8"]
|
||||
if self.weights not in accepted_weights:
|
||||
raise ValueError(f"Only support weights in {accepted_weights} but found {self.weights}")
|
||||
|
||||
|
||||
@dataclass
class FbgemmFp8Config(QuantizationConfigMixin):
    """
    Configuration holding the options that can be set when a model is loaded with fbgemm fp8 quantization.

    Args:
        activation_scale_ub (`float`, *optional*, defaults to 1200.0):
            The activation scale upper bound. This is used when quantizing the input activation.
        modules_to_not_convert (`list`, *optional*, defaults to `None`):
            The list of modules to not quantize, useful for quantizing models that explicitly require to have
            some modules left in their original precision.
    """

    def __init__(
        self,
        activation_scale_ub: float = 1200.0,
        modules_to_not_convert: Optional[List] = None,
        **kwargs,
    ):
        # Extra keyword arguments are accepted but not stored
        self.quant_method = QuantizationMethod.FBGEMM_FP8
        self.activation_scale_ub = activation_scale_ub
        self.modules_to_not_convert = modules_to_not_convert

    def get_loading_attributes(self):
        """
        Return a deep-copied dict restricted to the attributes listed as loading attributes
        (`activation_scale_ub`).
        """
        loading_keys = ["activation_scale_ub"]
        snapshot = copy.deepcopy(self.__dict__)
        return {key: value for key, value in snapshot.items() if key in loading_keys}
|
||||
|
||||
0
tests/quantization/fbgemm_fp8/__init__.py
Normal file
0
tests/quantization/fbgemm_fp8/__init__.py
Normal file
270
tests/quantization/fbgemm_fp8/test_fbgemm_fp8.py
Normal file
270
tests/quantization/fbgemm_fp8/test_fbgemm_fp8.py
Normal file
@@ -0,0 +1,270 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, FbgemmFp8Config, OPTForCausalLM
|
||||
from transformers.testing_utils import (
|
||||
require_accelerate,
|
||||
require_fbgemm_gpu,
|
||||
require_read_token,
|
||||
require_torch_gpu,
|
||||
require_torch_multi_gpu,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
from transformers.utils import is_accelerate_available, is_torch_available
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
if is_accelerate_available():
|
||||
from accelerate import init_empty_weights
|
||||
|
||||
|
||||
@require_torch_gpu
class FbgemmFp8ConfigTest(unittest.TestCase):
    def test_to_dict(self):
        """
        Simple test that checks if one uses a config and converts it to a dict, the dict is the same as the config object
        """
        quantization_config = FbgemmFp8Config()
        config_to_dict = quantization_config.to_dict()

        for key, value in config_to_dict.items():
            self.assertEqual(getattr(quantization_config, key), value)

    def test_from_dict(self):
        """
        Simple test that checks if one uses a dict and converts it to a config object, the config object is the same as the dict
        """
        # Named `config_dict` so that the builtin `dict` is not shadowed
        config_dict = {"modules_to_not_convert": ["lm_head.weight"], "quant_method": "fbgemm_fp8"}
        quantization_config = FbgemmFp8Config.from_dict(config_dict)

        self.assertEqual(config_dict["modules_to_not_convert"], quantization_config.modules_to_not_convert)
        self.assertEqual(config_dict["quant_method"], quantization_config.quant_method)
|
||||
|
||||
|
||||
@slow
@require_torch_gpu
@require_fbgemm_gpu
@require_accelerate
@require_read_token
class FbgemmFp8Test(unittest.TestCase):
    """
    End-to-end tests for FBGEMM FP8 quantization: module conversion, generation, save/reload,
    multi-GPU dispatch and cpu/disk offload.
    """

    model_name = "meta-llama/Meta-Llama-3-8B"

    input_text = "What are we having for dinner?"
    max_new_tokens = 9

    EXPECTED_OUTPUT = "What are we having for dinner?\nI'm having a steak and a salad"

    device_map = "cuda"

    # Device map mixing GPU 0, cpu and disk placements:
    # embeddings and layers 0-15 on GPU 0, layers 16-19 on cpu, layers 20-31 plus the final
    # norm and the lm_head on disk.
    offload_device_map = {
        "model.embed_tokens": 0,
        **{f"model.layers.{layer_idx}": 0 for layer_idx in range(16)},
        **{f"model.layers.{layer_idx}": "cpu" for layer_idx in range(16, 20)},
        **{f"model.layers.{layer_idx}": "disk" for layer_idx in range(20, 32)},
        "model.norm": "disk",
        "lm_head": "disk",
    }

    # called only once for all tests in this class
    @classmethod
    def setUpClass(cls):
        """
        Setup quantized model
        """
        quantization_config = FbgemmFp8Config()
        cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name)
        cls.quantized_model = AutoModelForCausalLM.from_pretrained(
            cls.model_name, device_map=cls.device_map, quantization_config=quantization_config
        )

    def tearDown(self):
        """Collect garbage and empty the CUDA cache after each test."""
        gc.collect()
        torch.cuda.empty_cache()
        gc.collect()

    def _assert_generation_matches_expected(self, model):
        """
        Generate `max_new_tokens` tokens from `input_text` with `model` and check that the decoded
        text equals `EXPECTED_OUTPUT`.
        """
        inputs = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)
        output = model.generate(**inputs, max_new_tokens=self.max_new_tokens)
        self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)

    @staticmethod
    def _count_modules(model, module_class):
        """Return how many submodules of `model` are instances of `module_class`."""
        return sum(isinstance(module, module_class) for module in model.modules())

    def test_quantized_model_conversion(self):
        """
        Simple test that checks if the quantized model has been converted properly
        """

        from transformers.integrations import FbgemmFp8Linear, replace_with_fbgemm_fp8_linear

        model_id = "facebook/opt-350m"
        config = AutoConfig.from_pretrained(model_id, revision="cb32f77e905cccbca1d970436fb0f5e6b58ee3c5")
        quantization_config = FbgemmFp8Config()

        # Build the model without allocating real weights; only the module structure matters here.
        with init_empty_weights():
            model = OPTForCausalLM(config)

        nb_linears = self._count_modules(model, torch.nn.Linear)

        model = replace_with_fbgemm_fp8_linear(model, quantization_config=quantization_config)
        nb_fbgemm_linear = self._count_modules(model, FbgemmFp8Linear)

        # One linear layer is expected to stay unconverted with the default config
        # (presumably the `lm_head` -- confirm against `replace_with_fbgemm_fp8_linear`).
        self.assertEqual(nb_linears - 1, nb_fbgemm_linear)

        with init_empty_weights():
            model = OPTForCausalLM(config)
        quantization_config = FbgemmFp8Config(modules_to_not_convert=["fc1"])
        model = replace_with_fbgemm_fp8_linear(model, quantization_config=quantization_config)
        nb_fbgemm_linear = self._count_modules(model, FbgemmFp8Linear)

        # 25 linears are expected to be skipped when `fc1` is excluded
        # (presumably one `fc1` per decoder layer plus the layer skipped above -- TODO confirm).
        self.assertEqual(nb_linears - 25, nb_fbgemm_linear)

    def test_quantized_model(self):
        """
        Simple test that checks if the quantized model is working properly
        """
        self._assert_generation_matches_expected(self.quantized_model)

    def test_save_pretrained(self):
        """
        Simple test that checks if the quantized model is working properly after being saved and loaded
        """
        with tempfile.TemporaryDirectory() as tmpdirname:
            self.quantized_model.save_pretrained(tmpdirname)

            model = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map=self.device_map)

            self._assert_generation_matches_expected(model)

    def test_change_loading_attributes(self):
        """
        Simple test that checks if a quantization attribute (`activation_scale_ub`) passed when reloading
        a saved quantized model is applied, and that the reloaded model still generates the expected text
        """
        with tempfile.TemporaryDirectory() as tmpdirname:
            self.quantized_model.save_pretrained(tmpdirname)

            quantization_config = FbgemmFp8Config(activation_scale_ub=1000.0)

            model = AutoModelForCausalLM.from_pretrained(
                tmpdirname, device_map=self.device_map, quantization_config=quantization_config
            )

            # The value given at load time must show up on a quantized linear layer.
            self.assertEqual(model.model.layers[1].mlp.down_proj.input_scale_ub.item(), 1000.0)

            self._assert_generation_matches_expected(model)

    @require_torch_multi_gpu
    def test_quantized_model_multi_gpu(self):
        """
        Simple test that checks if the quantized model is working properly with multiple GPUs
        set CUDA_VISIBLE_DEVICES=0,1 if you have more than 2 GPUs
        """
        quantization_config = FbgemmFp8Config()
        quantized_model = AutoModelForCausalLM.from_pretrained(
            self.model_name, device_map="auto", quantization_config=quantization_config
        )
        # assertEqual (rather than assertTrue on `==`) reports the actual device set on failure.
        self.assertEqual(set(quantized_model.hf_device_map.values()), {0, 1})

        self._assert_generation_matches_expected(quantized_model)

    def test_quantized_model_offload(self):
        """
        Simple test that checks if the quantized model returns an error when loading with cpu/disk offloaded
        """
        quantization_config = FbgemmFp8Config()

        with self.assertRaisesRegex(
            ValueError, "You are attempting to load an FP8 model with a device_map that contains a CPU or disk device."
        ):
            AutoModelForCausalLM.from_pretrained(
                self.model_name, device_map=self.offload_device_map, quantization_config=quantization_config
            )

    def test_save_pretrained_offload(self):
        """
        Simple test that checks if the saved quantized model is working properly with cpu/disk offload
        """
        with tempfile.TemporaryDirectory() as tmpdirname:
            self.quantized_model.save_pretrained(tmpdirname)

            quantized_model = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map=self.offload_device_map)

            self._assert_generation_matches_expected(quantized_model)

    @require_torch_multi_gpu
    def test_save_pretrained_multi_gpu(self):
        """
        Simple test that checks if the quantized model is working properly after being saved and
        reloaded across multiple GPUs with `device_map="auto"`
        """
        with tempfile.TemporaryDirectory() as tmpdirname:
            self.quantized_model.save_pretrained(tmpdirname)

            model = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map="auto")
            # assertEqual (rather than assertTrue on `==`) reports the actual device set on failure.
            self.assertEqual(set(model.hf_device_map.values()), {0, 1})

            self._assert_generation_matches_expected(model)
|
||||
Reference in New Issue
Block a user