Support AOPerModuleConfig and include_embedding (#37802)

* Support `AOPerModuleConfig` and include_embedding

Summary:
This PR adds support per module configuration for torchao
Also added per module quantization examples:

1. Quantizing different layers with different quantization configs
2. Skip quantization for certain layers

Test Plan:
python tests/quantization/torchao_integration/test_torchao.py -k test_include_embedding
python tests/quantization/torchao_integration/test_torchao.py -k test_per_module_config_skip

Reviewers:

Subscribers:

Tasks:

Tags:

* format

* format

* inlcude embedding remove input embedding from module not to convert

* more docs

* Update docs/source/en/quantization/torchao.md

Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>

* Update src/transformers/quantizers/quantizer_torchao.py

Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>

* Update src/transformers/quantizers/quantizer_torchao.py

Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>

---------

Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>
This commit is contained in:
Jerry Zhang
2025-04-30 11:16:29 -07:00
committed by GitHub
parent c3aeaa8060
commit 86777b5e2f
4 changed files with 167 additions and 6 deletions

View File

@@ -38,6 +38,13 @@ if is_torchao_available():
AffineQuantizedTensor,
TensorCoreTiledLayout,
)
from torchao.quantization import (
AOPerModuleConfig,
Int8WeightOnlyConfig,
IntxWeightOnlyConfig,
MappingType,
PerAxis,
)
from torchao.quantization.autoquant import AQMixin
if version.parse(importlib.metadata.version("torchao")) >= version.parse("0.8.0"):
@@ -193,6 +200,60 @@ class TorchAoTest(unittest.TestCase):
]
self.assertTrue(tokenizer.decode(output[0], skip_special_tokens=True) in EXPECTED_OUTPUT)
@require_torchao_version_greater_or_equal("0.11.0")
def test_include_embedding(self):
weight_dtype = torch.int8
granularity = PerAxis(0)
mapping_type = MappingType.ASYMMETRIC
embedding_config = IntxWeightOnlyConfig(
weight_dtype=weight_dtype,
granularity=granularity,
mapping_type=mapping_type,
)
config = AOPerModuleConfig({"_default": None, "model.embed_tokens": embedding_config})
# need set `include_embedding` to True
quant_config = TorchAoConfig(quant_type=config, include_embedding=True)
quantized_model = AutoModelForCausalLM.from_pretrained(
self.model_name,
device_map=self.device,
quantization_config=quant_config,
)
# making sure embedding is quantized
self.assertTrue(isinstance(quantized_model.model.embed_tokens.weight, AffineQuantizedTensor))
tokenizer = AutoTokenizer.from_pretrained(self.model_name)
input_ids = tokenizer(self.input_text, return_tensors="pt").to(self.device)
output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
EXPECTED_OUTPUT = [
"What are we having for dinner?\n\nJessica: (smiling)",
"What are we having for dinner?\n\nJess: (smiling) I",
]
self.assertTrue(tokenizer.decode(output[0], skip_special_tokens=True) in EXPECTED_OUTPUT)
@require_torchao_version_greater_or_equal("0.11.0")
def test_per_module_config_skip(self):
linear_config = Int8WeightOnlyConfig()
config = AOPerModuleConfig({"_default": linear_config, "model.layers.0.self_attn.q_proj": None})
quant_config = TorchAoConfig(quant_type=config)
quantized_model = AutoModelForCausalLM.from_pretrained(
self.model_name,
device_map=self.device,
quantization_config=quant_config,
)
# making sure `model.layers.0.self_attn.q_proj` is skipped
self.assertTrue(not isinstance(quantized_model.model.layers[0].self_attn.q_proj.weight, AffineQuantizedTensor))
tokenizer = AutoTokenizer.from_pretrained(self.model_name)
input_ids = tokenizer(self.input_text, return_tensors="pt").to(self.device)
output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
EXPECTED_OUTPUT = [
"What are we having for dinner?\n\nJessica: (smiling)",
"What are we having for dinner?\n\nJess: (smiling) I",
]
self.assertTrue(tokenizer.decode(output[0], skip_special_tokens=True) in EXPECTED_OUTPUT)
@require_torch_gpu
class TorchAoGPUTest(TorchAoTest):