Support AOPerModuleConfig and include_embedding (#37802)

* Support `AOPerModuleConfig` and include_embedding

Summary:
This PR adds support per module configuration for torchao
Also added per module quantization examples:

1. Quantizing different layers with different quantization configs
2. Skip quantization for certain layers

Test Plan:
python tests/quantization/torchao_integration/test_torchao.py -k test_include_embedding
python tests/quantization/torchao_integration/test_torchao.py -k test_per_module_config_skip

Reviewers:

Subscribers:

Tasks:

Tags:

* format

* format

* inlcude embedding remove input embedding from module not to convert

* more docs

* Update docs/source/en/quantization/torchao.md

Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>

* Update src/transformers/quantizers/quantizer_torchao.py

Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>

* Update src/transformers/quantizers/quantizer_torchao.py

Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>

---------

Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>
This commit is contained in:
Jerry Zhang
2025-04-30 11:16:29 -07:00
committed by GitHub
parent c3aeaa8060
commit 86777b5e2f
4 changed files with 167 additions and 6 deletions

View File

@@ -40,6 +40,8 @@ torchao supports the [quantization techniques](https://github.com/pytorch/ao/blo
- A16W4 Int4 Weight Only Quantization - A16W4 Int4 Weight Only Quantization
- Autoquantization - Autoquantization
torchao also supports module level configuration by specifying a dictionary from fully qualified name of module and its corresponding quantization config. This allows skip quantizing certain layers and using different quantization config for different modules.
Check the table below to see if your hardware is compatible. Check the table below to see if your hardware is compatible.
@@ -89,7 +91,7 @@ We'll show examples for recommended quantization methods based on hardwares, e.g
```py ```py
import torch import torch
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, Float8WeightOnlyConfig
quant_config = Float8DynamicActivationFloat8WeightConfig() quant_config = Float8DynamicActivationFloat8WeightConfig()
# or float8 weight only quantization # or float8 weight only quantization
@@ -149,7 +151,7 @@ print(tokenizer.decode(output[0], skip_special_tokens=True))
```py ```py
import torch import torch
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
from torchao.quantization import Int8DynamicActivationInt8WeightConfig from torchao.quantization import Int8DynamicActivationInt8WeightConfig, Int8WeightOnlyConfig
quant_config = Int8DynamicActivationInt8WeightConfig() quant_config = Int8DynamicActivationInt8WeightConfig()
# or int8 weight only quantization # or int8 weight only quantization
@@ -179,7 +181,7 @@ print(tokenizer.decode(output[0], skip_special_tokens=True))
```py ```py
import torch import torch
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
from torchao.quantization import GemliteUIntXWeightOnlyConfig from torchao.quantization import GemliteUIntXWeightOnlyConfig, Int4WeightOnlyConfig
# For batch size N, we recommend gemlite, which may require autotuning # For batch size N, we recommend gemlite, which may require autotuning
# default is 4 bit, 8 bit is also supported by passing `bit_width=8` # default is 4 bit, 8 bit is also supported by passing `bit_width=8`
@@ -216,7 +218,7 @@ print(tokenizer.decode(output[0], skip_special_tokens=True))
```py ```py
import torch import torch
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
from torchao.quantization import Int8DynamicActivationInt8WeightConfig from torchao.quantization import Int8DynamicActivationInt8WeightConfig, Int8WeightOnlyConfig
quant_config = Int8DynamicActivationInt8WeightConfig() quant_config = Int8DynamicActivationInt8WeightConfig()
# quant_config = Int8WeightOnlyConfig() # quant_config = Int8WeightOnlyConfig()
@@ -272,6 +274,74 @@ print(tokenizer.decode(output[0], skip_special_tokens=True))
</hfoption> </hfoption>
</hfoptions> </hfoptions>
### Per Module Quantization
#### 1. Skip quantization for certain layers
With `AOPerModuleConfig` we can specify a default configuration for all layers while skipping quantization for certain layers.
```py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
model_id = "meta-llama/Llama-3.1-8B-Instruct"
from torchao.quantization import Int4WeightOnlyConfig, AOPerModuleConfig
config = Int4WeightOnlyConfig(group_size=128)
# set default to int4 (for linears), and skip quantizing `model.layers.0.self_attn.q_proj`
quant_config = AOPerModuleConfig({"_default": config, "model.layers.0.self_attn.q_proj": None})
quantization_config = TorchAoConfig(quant_type=quant_config)
quantized_model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", torch_dtype=torch.bfloat16, quantization_config=quantization_config)
# lm_head is not quantized and model.layers.0.self_attn.q_proj is not quantized
print("quantized model:", quantized_model)
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Manual Testing
prompt = "Hey, are you conscious? Can you talk to me?"
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
generated_ids = quantized_model.generate(**inputs, max_new_tokens=128)
output_text = tokenizer.batch_decode(
generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
print(output_text)
```
#### 2. Quantizing different layers with different quantization configs
```py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
model_id = "facebook/opt-125m"
from torchao.quantization import Int4WeightOnlyConfig, AOPerModuleConfig, Int8DynamicActivationInt4WeightConfig, IntxWeightOnlyConfig, PerAxis, MappingType
weight_dtype = torch.int8
granularity = PerAxis(0)
mapping_type = MappingType.ASYMMETRIC
embedding_config = IntxWeightOnlyConfig(
weight_dtype=weight_dtype,
granularity=granularity,
mapping_type=mapping_type,
)
linear_config = Int8DynamicActivationInt4WeightConfig(group_size=128)
quant_config = AOPerModuleConfig({"_default": linear_config, "model.decoder.embed_tokens": embedding_config, "model.decoder.embed_positions": None})
# set `include_embedding` to True in order to include embedding in quantization
# when `include_embedding` is True, we'll remove input embedding from `modules_not_to_convert` as well
quantization_config = TorchAoConfig(quant_type=quant_config, include_embedding=True)
quantized_model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cpu", torch_dtype=torch.bfloat16, quantization_config=quantization_config)
print("quantized model:", quantized_model)
# make sure embedding is quantized
print("embed_tokens weight:", quantized_model.model.decoder.embed_tokens.weight)
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Manual Testing
prompt = "Hey, are you conscious? Can you talk to me?"
inputs = tokenizer(prompt, return_tensors="pt").to("cpu")
generated_ids = quantized_model.generate(**inputs, max_new_tokens=128, cache_implementation="static")
output_text = tokenizer.batch_decode(
generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
print(output_text)
```
### Autoquant ### Autoquant
If you want to automatically choose a quantization type for quantizable layers (`nn.Linear`) you can use the [autoquant](https://pytorch.org/ao/stable/generated/torchao.quantization.autoquant.html#torchao.quantization.autoquant) API. If you want to automatically choose a quantization type for quantizable layers (`nn.Linear`) you can use the [autoquant](https://pytorch.org/ao/stable/generated/torchao.quantization.autoquant.html#torchao.quantization.autoquant) API.

View File

@@ -185,6 +185,10 @@ class TorchAoHfQuantizer(HfQuantizer):
self.modules_to_not_convert = self.get_modules_to_not_convert( self.modules_to_not_convert = self.get_modules_to_not_convert(
model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules
) )
if self.quantization_config.include_embedding:
input_emb = model.get_input_embeddings()
input_emb_names = [name for name, module in model.named_modules() if id(module) == id(input_emb)]
self.modules_to_not_convert = [x for x in self.modules_to_not_convert if x not in input_emb_names]
return return
def check_quantized_param( def check_quantized_param(
@@ -206,9 +210,12 @@ class TorchAoHfQuantizer(HfQuantizer):
# We don't quantize weights that we offload # We don't quantize weights that we offload
return False return False
else: else:
# we only quantize the weight of nn.Linear # we only quantize the weight of nn.Linear and nn.Embedding
module, tensor_name = get_module_from_name(model, param_name) module, tensor_name = get_module_from_name(model, param_name)
return isinstance(module, torch.nn.Linear) and (tensor_name == "weight") _QUANTIZABLE = [torch.nn.Linear]
if self.quantization_config.include_embedding:
_QUANTIZABLE.append(torch.nn.Embedding)
return isinstance(module, tuple(_QUANTIZABLE)) and (tensor_name == "weight")
def create_quantized_param( def create_quantized_param(
self, self,
@@ -240,6 +247,23 @@ class TorchAoHfQuantizer(HfQuantizer):
module._parameters[tensor_name] = torch.nn.Parameter( module._parameters[tensor_name] = torch.nn.Parameter(
param_value, requires_grad=param_value.requires_grad param_value, requires_grad=param_value.requires_grad
).to(device=target_device) ).to(device=target_device)
# handle AOPerModuleConfig, introduced in torchao 0.11.0+
if self.quantization_config._get_ao_version() > version.Version("0.10.0"):
from torchao.quantization import AOPerModuleConfig
config = self.quantization_config.get_apply_tensor_subclass()
if isinstance(config, AOPerModuleConfig):
module_fqn, _ = param_name.rsplit(".", 1)
c = None
if module_fqn in config.module_fqn_to_config:
c = config.module_fqn_to_config[module_fqn]
else:
c = config.module_fqn_to_config.get("_default", None)
if c is not None:
# filter_fn: not filtering out any modules
quantize_(module, c, filter_fn=lambda x, fqn: True)
return
quantize_(module, self.quantization_config.get_apply_tensor_subclass()) quantize_(module, self.quantization_config.get_apply_tensor_subclass())
def _process_model_after_weight_loading(self, model, **kwargs): def _process_model_after_weight_loading(self, model, **kwargs):

View File

@@ -1554,6 +1554,7 @@ class TorchAoConfig(QuantizationConfigMixin):
quant_type: Union[str, "AOBaseConfig"] # noqa: F821 quant_type: Union[str, "AOBaseConfig"] # noqa: F821
modules_to_not_convert: Optional[List] modules_to_not_convert: Optional[List]
quant_type_kwargs: Dict[str, Any] quant_type_kwargs: Dict[str, Any]
include_embedding: bool
"""This is a config class for torchao quantization/sparsity techniques. """This is a config class for torchao quantization/sparsity techniques.
@@ -1565,6 +1566,9 @@ class TorchAoConfig(QuantizationConfigMixin):
modules_to_not_convert (`list`, *optional*, default to `None`): modules_to_not_convert (`list`, *optional*, default to `None`):
The list of modules to not quantize, useful for quantizing models that explicitly require to have The list of modules to not quantize, useful for quantizing models that explicitly require to have
some modules left in their original precision. some modules left in their original precision.
inlcude_embedding (`bool`, default to `False`):
Whether to include embedding in quantization or not, input embedding will be removed from
the module_not_to_convert list as well if this flag is set.
kwargs (`Dict[str, Any]`, *optional*): kwargs (`Dict[str, Any]`, *optional*):
The keyword arguments for the chosen type of quantization, for example, int4_weight_only quantization supports two keyword arguments The keyword arguments for the chosen type of quantization, for example, int4_weight_only quantization supports two keyword arguments
`group_size` and `inner_k_tiles` currently. More API examples and documentation of arguments can be found in `group_size` and `inner_k_tiles` currently. More API examples and documentation of arguments can be found in
@@ -1609,12 +1613,14 @@ class TorchAoConfig(QuantizationConfigMixin):
self, self,
quant_type: Union[str, "AOBaseConfig"], # noqa: F821 quant_type: Union[str, "AOBaseConfig"], # noqa: F821
modules_to_not_convert: Optional[List] = None, modules_to_not_convert: Optional[List] = None,
include_embedding: bool = False,
**kwargs, **kwargs,
): ):
self.quant_method = QuantizationMethod.TORCHAO self.quant_method = QuantizationMethod.TORCHAO
self.quant_type = quant_type self.quant_type = quant_type
self.modules_to_not_convert = modules_to_not_convert self.modules_to_not_convert = modules_to_not_convert
self.quant_type_kwargs = kwargs.get("quant_type_kwargs", kwargs) self.quant_type_kwargs = kwargs.get("quant_type_kwargs", kwargs)
self.include_embedding = include_embedding
self.post_init() self.post_init()
@staticmethod @staticmethod

View File

@@ -38,6 +38,13 @@ if is_torchao_available():
AffineQuantizedTensor, AffineQuantizedTensor,
TensorCoreTiledLayout, TensorCoreTiledLayout,
) )
from torchao.quantization import (
AOPerModuleConfig,
Int8WeightOnlyConfig,
IntxWeightOnlyConfig,
MappingType,
PerAxis,
)
from torchao.quantization.autoquant import AQMixin from torchao.quantization.autoquant import AQMixin
if version.parse(importlib.metadata.version("torchao")) >= version.parse("0.8.0"): if version.parse(importlib.metadata.version("torchao")) >= version.parse("0.8.0"):
@@ -193,6 +200,60 @@ class TorchAoTest(unittest.TestCase):
] ]
self.assertTrue(tokenizer.decode(output[0], skip_special_tokens=True) in EXPECTED_OUTPUT) self.assertTrue(tokenizer.decode(output[0], skip_special_tokens=True) in EXPECTED_OUTPUT)
@require_torchao_version_greater_or_equal("0.11.0")
def test_include_embedding(self):
weight_dtype = torch.int8
granularity = PerAxis(0)
mapping_type = MappingType.ASYMMETRIC
embedding_config = IntxWeightOnlyConfig(
weight_dtype=weight_dtype,
granularity=granularity,
mapping_type=mapping_type,
)
config = AOPerModuleConfig({"_default": None, "model.embed_tokens": embedding_config})
# need set `include_embedding` to True
quant_config = TorchAoConfig(quant_type=config, include_embedding=True)
quantized_model = AutoModelForCausalLM.from_pretrained(
self.model_name,
device_map=self.device,
quantization_config=quant_config,
)
# making sure embedding is quantized
self.assertTrue(isinstance(quantized_model.model.embed_tokens.weight, AffineQuantizedTensor))
tokenizer = AutoTokenizer.from_pretrained(self.model_name)
input_ids = tokenizer(self.input_text, return_tensors="pt").to(self.device)
output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
EXPECTED_OUTPUT = [
"What are we having for dinner?\n\nJessica: (smiling)",
"What are we having for dinner?\n\nJess: (smiling) I",
]
self.assertTrue(tokenizer.decode(output[0], skip_special_tokens=True) in EXPECTED_OUTPUT)
@require_torchao_version_greater_or_equal("0.11.0")
def test_per_module_config_skip(self):
linear_config = Int8WeightOnlyConfig()
config = AOPerModuleConfig({"_default": linear_config, "model.layers.0.self_attn.q_proj": None})
quant_config = TorchAoConfig(quant_type=config)
quantized_model = AutoModelForCausalLM.from_pretrained(
self.model_name,
device_map=self.device,
quantization_config=quant_config,
)
# making sure `model.layers.0.self_attn.q_proj` is skipped
self.assertTrue(not isinstance(quantized_model.model.layers[0].self_attn.q_proj.weight, AffineQuantizedTensor))
tokenizer = AutoTokenizer.from_pretrained(self.model_name)
input_ids = tokenizer(self.input_text, return_tensors="pt").to(self.device)
output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
EXPECTED_OUTPUT = [
"What are we having for dinner?\n\nJessica: (smiling)",
"What are we having for dinner?\n\nJess: (smiling) I",
]
self.assertTrue(tokenizer.decode(output[0], skip_special_tokens=True) in EXPECTED_OUTPUT)
@require_torch_gpu @require_torch_gpu
class TorchAoGPUTest(TorchAoTest): class TorchAoGPUTest(TorchAoTest):