Support AOPerModuleConfig and include_embedding (#37802)
* Support `AOPerModuleConfig` and include_embedding Summary: This PR adds support per module configuration for torchao Also added per module quantization examples: 1. Quantizing different layers with different quantization configs 2. Skip quantization for certain layers Test Plan: python tests/quantization/torchao_integration/test_torchao.py -k test_include_embedding python tests/quantization/torchao_integration/test_torchao.py -k test_per_module_config_skip Reviewers: Subscribers: Tasks: Tags: * format * format * inlcude embedding remove input embedding from module not to convert * more docs * Update docs/source/en/quantization/torchao.md Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com> * Update src/transformers/quantizers/quantizer_torchao.py Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com> * Update src/transformers/quantizers/quantizer_torchao.py Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com> --------- Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>
This commit is contained in:
@@ -40,6 +40,8 @@ torchao supports the [quantization techniques](https://github.com/pytorch/ao/blo
|
|||||||
- A16W4 Int4 Weight Only Quantization
|
- A16W4 Int4 Weight Only Quantization
|
||||||
- Autoquantization
|
- Autoquantization
|
||||||
|
|
||||||
|
torchao also supports module level configuration by specifying a dictionary from fully qualified name of module and its corresponding quantization config. This allows skip quantizing certain layers and using different quantization config for different modules.
|
||||||
|
|
||||||
|
|
||||||
Check the table below to see if your hardware is compatible.
|
Check the table below to see if your hardware is compatible.
|
||||||
|
|
||||||
@@ -89,7 +91,7 @@ We'll show examples for recommended quantization methods based on hardwares, e.g
|
|||||||
```py
|
```py
|
||||||
import torch
|
import torch
|
||||||
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
|
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
|
||||||
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
|
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, Float8WeightOnlyConfig
|
||||||
|
|
||||||
quant_config = Float8DynamicActivationFloat8WeightConfig()
|
quant_config = Float8DynamicActivationFloat8WeightConfig()
|
||||||
# or float8 weight only quantization
|
# or float8 weight only quantization
|
||||||
@@ -149,7 +151,7 @@ print(tokenizer.decode(output[0], skip_special_tokens=True))
|
|||||||
```py
|
```py
|
||||||
import torch
|
import torch
|
||||||
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
|
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
|
||||||
from torchao.quantization import Int8DynamicActivationInt8WeightConfig
|
from torchao.quantization import Int8DynamicActivationInt8WeightConfig, Int8WeightOnlyConfig
|
||||||
|
|
||||||
quant_config = Int8DynamicActivationInt8WeightConfig()
|
quant_config = Int8DynamicActivationInt8WeightConfig()
|
||||||
# or int8 weight only quantization
|
# or int8 weight only quantization
|
||||||
@@ -179,7 +181,7 @@ print(tokenizer.decode(output[0], skip_special_tokens=True))
|
|||||||
```py
|
```py
|
||||||
import torch
|
import torch
|
||||||
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
|
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
|
||||||
from torchao.quantization import GemliteUIntXWeightOnlyConfig
|
from torchao.quantization import GemliteUIntXWeightOnlyConfig, Int4WeightOnlyConfig
|
||||||
|
|
||||||
# For batch size N, we recommend gemlite, which may require autotuning
|
# For batch size N, we recommend gemlite, which may require autotuning
|
||||||
# default is 4 bit, 8 bit is also supported by passing `bit_width=8`
|
# default is 4 bit, 8 bit is also supported by passing `bit_width=8`
|
||||||
@@ -216,7 +218,7 @@ print(tokenizer.decode(output[0], skip_special_tokens=True))
|
|||||||
```py
|
```py
|
||||||
import torch
|
import torch
|
||||||
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
|
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
|
||||||
from torchao.quantization import Int8DynamicActivationInt8WeightConfig
|
from torchao.quantization import Int8DynamicActivationInt8WeightConfig, Int8WeightOnlyConfig
|
||||||
|
|
||||||
quant_config = Int8DynamicActivationInt8WeightConfig()
|
quant_config = Int8DynamicActivationInt8WeightConfig()
|
||||||
# quant_config = Int8WeightOnlyConfig()
|
# quant_config = Int8WeightOnlyConfig()
|
||||||
@@ -272,6 +274,74 @@ print(tokenizer.decode(output[0], skip_special_tokens=True))
|
|||||||
</hfoption>
|
</hfoption>
|
||||||
</hfoptions>
|
</hfoptions>
|
||||||
|
|
||||||
|
### Per Module Quantization
|
||||||
|
#### 1. Skip quantization for certain layers
|
||||||
|
With `AOPerModuleConfig` we can specify a default configuration for all layers while skipping quantization for certain layers.
|
||||||
|
```py
|
||||||
|
import torch
|
||||||
|
from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
|
||||||
|
|
||||||
|
model_id = "meta-llama/Llama-3.1-8B-Instruct"
|
||||||
|
|
||||||
|
from torchao.quantization import Int4WeightOnlyConfig, AOPerModuleConfig
|
||||||
|
config = Int4WeightOnlyConfig(group_size=128)
|
||||||
|
|
||||||
|
# set default to int4 (for linears), and skip quantizing `model.layers.0.self_attn.q_proj`
|
||||||
|
quant_config = AOPerModuleConfig({"_default": config, "model.layers.0.self_attn.q_proj": None})
|
||||||
|
quantization_config = TorchAoConfig(quant_type=quant_config)
|
||||||
|
quantized_model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", torch_dtype=torch.bfloat16, quantization_config=quantization_config)
|
||||||
|
# lm_head is not quantized and model.layers.0.self_attn.q_proj is not quantized
|
||||||
|
print("quantized model:", quantized_model)
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||||
|
|
||||||
|
# Manual Testing
|
||||||
|
prompt = "Hey, are you conscious? Can you talk to me?"
|
||||||
|
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
|
||||||
|
generated_ids = quantized_model.generate(**inputs, max_new_tokens=128)
|
||||||
|
output_text = tokenizer.batch_decode(
|
||||||
|
generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
||||||
|
)
|
||||||
|
print(output_text)
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 2. Quantizing different layers with different quantization configs
|
||||||
|
```py
|
||||||
|
import torch
|
||||||
|
from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
|
||||||
|
|
||||||
|
model_id = "facebook/opt-125m"
|
||||||
|
|
||||||
|
from torchao.quantization import Int4WeightOnlyConfig, AOPerModuleConfig, Int8DynamicActivationInt4WeightConfig, IntxWeightOnlyConfig, PerAxis, MappingType
|
||||||
|
|
||||||
|
weight_dtype = torch.int8
|
||||||
|
granularity = PerAxis(0)
|
||||||
|
mapping_type = MappingType.ASYMMETRIC
|
||||||
|
embedding_config = IntxWeightOnlyConfig(
|
||||||
|
weight_dtype=weight_dtype,
|
||||||
|
granularity=granularity,
|
||||||
|
mapping_type=mapping_type,
|
||||||
|
)
|
||||||
|
linear_config = Int8DynamicActivationInt4WeightConfig(group_size=128)
|
||||||
|
quant_config = AOPerModuleConfig({"_default": linear_config, "model.decoder.embed_tokens": embedding_config, "model.decoder.embed_positions": None})
|
||||||
|
# set `include_embedding` to True in order to include embedding in quantization
|
||||||
|
# when `include_embedding` is True, we'll remove input embedding from `modules_not_to_convert` as well
|
||||||
|
quantization_config = TorchAoConfig(quant_type=quant_config, include_embedding=True)
|
||||||
|
quantized_model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cpu", torch_dtype=torch.bfloat16, quantization_config=quantization_config)
|
||||||
|
print("quantized model:", quantized_model)
|
||||||
|
# make sure embedding is quantized
|
||||||
|
print("embed_tokens weight:", quantized_model.model.decoder.embed_tokens.weight)
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||||
|
|
||||||
|
# Manual Testing
|
||||||
|
prompt = "Hey, are you conscious? Can you talk to me?"
|
||||||
|
inputs = tokenizer(prompt, return_tensors="pt").to("cpu")
|
||||||
|
generated_ids = quantized_model.generate(**inputs, max_new_tokens=128, cache_implementation="static")
|
||||||
|
output_text = tokenizer.batch_decode(
|
||||||
|
generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
||||||
|
)
|
||||||
|
print(output_text)
|
||||||
|
```
|
||||||
|
|
||||||
### Autoquant
|
### Autoquant
|
||||||
|
|
||||||
If you want to automatically choose a quantization type for quantizable layers (`nn.Linear`) you can use the [autoquant](https://pytorch.org/ao/stable/generated/torchao.quantization.autoquant.html#torchao.quantization.autoquant) API.
|
If you want to automatically choose a quantization type for quantizable layers (`nn.Linear`) you can use the [autoquant](https://pytorch.org/ao/stable/generated/torchao.quantization.autoquant.html#torchao.quantization.autoquant) API.
|
||||||
|
|||||||
@@ -185,6 +185,10 @@ class TorchAoHfQuantizer(HfQuantizer):
|
|||||||
self.modules_to_not_convert = self.get_modules_to_not_convert(
|
self.modules_to_not_convert = self.get_modules_to_not_convert(
|
||||||
model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules
|
model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules
|
||||||
)
|
)
|
||||||
|
if self.quantization_config.include_embedding:
|
||||||
|
input_emb = model.get_input_embeddings()
|
||||||
|
input_emb_names = [name for name, module in model.named_modules() if id(module) == id(input_emb)]
|
||||||
|
self.modules_to_not_convert = [x for x in self.modules_to_not_convert if x not in input_emb_names]
|
||||||
return
|
return
|
||||||
|
|
||||||
def check_quantized_param(
|
def check_quantized_param(
|
||||||
@@ -206,9 +210,12 @@ class TorchAoHfQuantizer(HfQuantizer):
|
|||||||
# We don't quantize weights that we offload
|
# We don't quantize weights that we offload
|
||||||
return False
|
return False
|
||||||
else:
|
else:
|
||||||
# we only quantize the weight of nn.Linear
|
# we only quantize the weight of nn.Linear and nn.Embedding
|
||||||
module, tensor_name = get_module_from_name(model, param_name)
|
module, tensor_name = get_module_from_name(model, param_name)
|
||||||
return isinstance(module, torch.nn.Linear) and (tensor_name == "weight")
|
_QUANTIZABLE = [torch.nn.Linear]
|
||||||
|
if self.quantization_config.include_embedding:
|
||||||
|
_QUANTIZABLE.append(torch.nn.Embedding)
|
||||||
|
return isinstance(module, tuple(_QUANTIZABLE)) and (tensor_name == "weight")
|
||||||
|
|
||||||
def create_quantized_param(
|
def create_quantized_param(
|
||||||
self,
|
self,
|
||||||
@@ -240,6 +247,23 @@ class TorchAoHfQuantizer(HfQuantizer):
|
|||||||
module._parameters[tensor_name] = torch.nn.Parameter(
|
module._parameters[tensor_name] = torch.nn.Parameter(
|
||||||
param_value, requires_grad=param_value.requires_grad
|
param_value, requires_grad=param_value.requires_grad
|
||||||
).to(device=target_device)
|
).to(device=target_device)
|
||||||
|
# handle AOPerModuleConfig, introduced in torchao 0.11.0+
|
||||||
|
if self.quantization_config._get_ao_version() > version.Version("0.10.0"):
|
||||||
|
from torchao.quantization import AOPerModuleConfig
|
||||||
|
|
||||||
|
config = self.quantization_config.get_apply_tensor_subclass()
|
||||||
|
if isinstance(config, AOPerModuleConfig):
|
||||||
|
module_fqn, _ = param_name.rsplit(".", 1)
|
||||||
|
c = None
|
||||||
|
if module_fqn in config.module_fqn_to_config:
|
||||||
|
c = config.module_fqn_to_config[module_fqn]
|
||||||
|
else:
|
||||||
|
c = config.module_fqn_to_config.get("_default", None)
|
||||||
|
if c is not None:
|
||||||
|
# filter_fn: not filtering out any modules
|
||||||
|
quantize_(module, c, filter_fn=lambda x, fqn: True)
|
||||||
|
return
|
||||||
|
|
||||||
quantize_(module, self.quantization_config.get_apply_tensor_subclass())
|
quantize_(module, self.quantization_config.get_apply_tensor_subclass())
|
||||||
|
|
||||||
def _process_model_after_weight_loading(self, model, **kwargs):
|
def _process_model_after_weight_loading(self, model, **kwargs):
|
||||||
|
|||||||
@@ -1554,6 +1554,7 @@ class TorchAoConfig(QuantizationConfigMixin):
|
|||||||
quant_type: Union[str, "AOBaseConfig"] # noqa: F821
|
quant_type: Union[str, "AOBaseConfig"] # noqa: F821
|
||||||
modules_to_not_convert: Optional[List]
|
modules_to_not_convert: Optional[List]
|
||||||
quant_type_kwargs: Dict[str, Any]
|
quant_type_kwargs: Dict[str, Any]
|
||||||
|
include_embedding: bool
|
||||||
|
|
||||||
"""This is a config class for torchao quantization/sparsity techniques.
|
"""This is a config class for torchao quantization/sparsity techniques.
|
||||||
|
|
||||||
@@ -1565,6 +1566,9 @@ class TorchAoConfig(QuantizationConfigMixin):
|
|||||||
modules_to_not_convert (`list`, *optional*, default to `None`):
|
modules_to_not_convert (`list`, *optional*, default to `None`):
|
||||||
The list of modules to not quantize, useful for quantizing models that explicitly require to have
|
The list of modules to not quantize, useful for quantizing models that explicitly require to have
|
||||||
some modules left in their original precision.
|
some modules left in their original precision.
|
||||||
|
inlcude_embedding (`bool`, default to `False`):
|
||||||
|
Whether to include embedding in quantization or not, input embedding will be removed from
|
||||||
|
the module_not_to_convert list as well if this flag is set.
|
||||||
kwargs (`Dict[str, Any]`, *optional*):
|
kwargs (`Dict[str, Any]`, *optional*):
|
||||||
The keyword arguments for the chosen type of quantization, for example, int4_weight_only quantization supports two keyword arguments
|
The keyword arguments for the chosen type of quantization, for example, int4_weight_only quantization supports two keyword arguments
|
||||||
`group_size` and `inner_k_tiles` currently. More API examples and documentation of arguments can be found in
|
`group_size` and `inner_k_tiles` currently. More API examples and documentation of arguments can be found in
|
||||||
@@ -1609,12 +1613,14 @@ class TorchAoConfig(QuantizationConfigMixin):
|
|||||||
self,
|
self,
|
||||||
quant_type: Union[str, "AOBaseConfig"], # noqa: F821
|
quant_type: Union[str, "AOBaseConfig"], # noqa: F821
|
||||||
modules_to_not_convert: Optional[List] = None,
|
modules_to_not_convert: Optional[List] = None,
|
||||||
|
include_embedding: bool = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
self.quant_method = QuantizationMethod.TORCHAO
|
self.quant_method = QuantizationMethod.TORCHAO
|
||||||
self.quant_type = quant_type
|
self.quant_type = quant_type
|
||||||
self.modules_to_not_convert = modules_to_not_convert
|
self.modules_to_not_convert = modules_to_not_convert
|
||||||
self.quant_type_kwargs = kwargs.get("quant_type_kwargs", kwargs)
|
self.quant_type_kwargs = kwargs.get("quant_type_kwargs", kwargs)
|
||||||
|
self.include_embedding = include_embedding
|
||||||
self.post_init()
|
self.post_init()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
@@ -38,6 +38,13 @@ if is_torchao_available():
|
|||||||
AffineQuantizedTensor,
|
AffineQuantizedTensor,
|
||||||
TensorCoreTiledLayout,
|
TensorCoreTiledLayout,
|
||||||
)
|
)
|
||||||
|
from torchao.quantization import (
|
||||||
|
AOPerModuleConfig,
|
||||||
|
Int8WeightOnlyConfig,
|
||||||
|
IntxWeightOnlyConfig,
|
||||||
|
MappingType,
|
||||||
|
PerAxis,
|
||||||
|
)
|
||||||
from torchao.quantization.autoquant import AQMixin
|
from torchao.quantization.autoquant import AQMixin
|
||||||
|
|
||||||
if version.parse(importlib.metadata.version("torchao")) >= version.parse("0.8.0"):
|
if version.parse(importlib.metadata.version("torchao")) >= version.parse("0.8.0"):
|
||||||
@@ -193,6 +200,60 @@ class TorchAoTest(unittest.TestCase):
|
|||||||
]
|
]
|
||||||
self.assertTrue(tokenizer.decode(output[0], skip_special_tokens=True) in EXPECTED_OUTPUT)
|
self.assertTrue(tokenizer.decode(output[0], skip_special_tokens=True) in EXPECTED_OUTPUT)
|
||||||
|
|
||||||
|
@require_torchao_version_greater_or_equal("0.11.0")
|
||||||
|
def test_include_embedding(self):
|
||||||
|
weight_dtype = torch.int8
|
||||||
|
granularity = PerAxis(0)
|
||||||
|
mapping_type = MappingType.ASYMMETRIC
|
||||||
|
embedding_config = IntxWeightOnlyConfig(
|
||||||
|
weight_dtype=weight_dtype,
|
||||||
|
granularity=granularity,
|
||||||
|
mapping_type=mapping_type,
|
||||||
|
)
|
||||||
|
config = AOPerModuleConfig({"_default": None, "model.embed_tokens": embedding_config})
|
||||||
|
# need set `include_embedding` to True
|
||||||
|
quant_config = TorchAoConfig(quant_type=config, include_embedding=True)
|
||||||
|
quantized_model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
self.model_name,
|
||||||
|
device_map=self.device,
|
||||||
|
quantization_config=quant_config,
|
||||||
|
)
|
||||||
|
# making sure embedding is quantized
|
||||||
|
self.assertTrue(isinstance(quantized_model.model.embed_tokens.weight, AffineQuantizedTensor))
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(self.model_name)
|
||||||
|
|
||||||
|
input_ids = tokenizer(self.input_text, return_tensors="pt").to(self.device)
|
||||||
|
|
||||||
|
output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
|
||||||
|
EXPECTED_OUTPUT = [
|
||||||
|
"What are we having for dinner?\n\nJessica: (smiling)",
|
||||||
|
"What are we having for dinner?\n\nJess: (smiling) I",
|
||||||
|
]
|
||||||
|
self.assertTrue(tokenizer.decode(output[0], skip_special_tokens=True) in EXPECTED_OUTPUT)
|
||||||
|
|
||||||
|
@require_torchao_version_greater_or_equal("0.11.0")
|
||||||
|
def test_per_module_config_skip(self):
|
||||||
|
linear_config = Int8WeightOnlyConfig()
|
||||||
|
config = AOPerModuleConfig({"_default": linear_config, "model.layers.0.self_attn.q_proj": None})
|
||||||
|
quant_config = TorchAoConfig(quant_type=config)
|
||||||
|
quantized_model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
self.model_name,
|
||||||
|
device_map=self.device,
|
||||||
|
quantization_config=quant_config,
|
||||||
|
)
|
||||||
|
# making sure `model.layers.0.self_attn.q_proj` is skipped
|
||||||
|
self.assertTrue(not isinstance(quantized_model.model.layers[0].self_attn.q_proj.weight, AffineQuantizedTensor))
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(self.model_name)
|
||||||
|
|
||||||
|
input_ids = tokenizer(self.input_text, return_tensors="pt").to(self.device)
|
||||||
|
|
||||||
|
output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
|
||||||
|
EXPECTED_OUTPUT = [
|
||||||
|
"What are we having for dinner?\n\nJessica: (smiling)",
|
||||||
|
"What are we having for dinner?\n\nJess: (smiling) I",
|
||||||
|
]
|
||||||
|
self.assertTrue(tokenizer.decode(output[0], skip_special_tokens=True) in EXPECTED_OUTPUT)
|
||||||
|
|
||||||
|
|
||||||
@require_torch_gpu
|
@require_torch_gpu
|
||||||
class TorchAoGPUTest(TorchAoTest):
|
class TorchAoGPUTest(TorchAoTest):
|
||||||
|
|||||||
Reference in New Issue
Block a user