Add autoquant support for torchao quantizer (#35503)
* Add autoquant support for torchao quantizer

  Summary: As titled. Also verified that an autoquantized model can be saved and loaded:
  - save: https://gist.github.com/jerryzh168/01d367aaf44dbbbfd4068a4a10a00061
  - load: https://gist.github.com/jerryzh168/d5c6c401b2abdf18e0b6771341f1525c

  Test Plan: Tested locally with the scripts above; the resulting model was uploaded to https://huggingface.co/jerryzh168/llama3-8b-autoquant

  Reviewers: Subscribers: Tasks: Tags:

* add test
* ruff fix
* ruff reformat
* add docs and min_sqnr support
* format
* format
* fix test
* update doc
* format
* remove disable_compile
* format
This commit is contained in:
@@ -31,10 +31,12 @@ if is_torch_available():
|
||||
import torch
|
||||
|
||||
if is_torchao_available():
|
||||
# renamed in torchao 0.7.0, please install the latest torchao
|
||||
from torchao.dtypes import (
|
||||
AffineQuantizedTensor,
|
||||
TensorCoreTiledLayout,
|
||||
)
|
||||
from torchao.quantization.autoquant import AQMixin
|
||||
|
||||
|
||||
def check_torchao_quantized(test_module, qlayer, batch_size=1, context_size=1024):
|
||||
@@ -42,7 +44,12 @@ def check_torchao_quantized(test_module, qlayer, batch_size=1, context_size=1024
|
||||
test_module.assertTrue(isinstance(weight, AffineQuantizedTensor))
|
||||
test_module.assertEqual(weight.quant_min, 0)
|
||||
test_module.assertEqual(weight.quant_max, 15)
|
||||
test_module.assertTrue(isinstance(weight.layout, TensorCoreTiledLayout))
|
||||
test_module.assertTrue(isinstance(weight._layout, TensorCoreTiledLayout))
|
||||
|
||||
|
||||
def check_autoquantized(test_module, qlayer):
    """
    Assert that the weight of `qlayer` is an autoquant tensor.

    Args:
        test_module: The running test case; its assertion helpers are used to report failures.
        qlayer: The layer whose `weight` attribute is inspected.
    """
    # assertIsInstance reports the actual type on failure, unlike assertTrue(isinstance(...)).
    test_module.assertIsInstance(qlayer.weight, AQMixin)
|
||||
|
||||
|
||||
def check_forward(test_module, model, batch_size=1, context_size=1024):
|
||||
@@ -248,6 +255,33 @@ class TorchAoTest(unittest.TestCase):
|
||||
EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
|
||||
self.assertEqual(tokenizer.decode(output[0], skip_special_tokens=True), EXPECTED_OUTPUT)
|
||||
|
||||
def test_autoquant(self):
    """
    Simple LLM model testing autoquant.

    Loads the model with the "autoquant" torchao config, runs one generation pass,
    calls `finalize_autoquant()`, then checks that a projection weight is an `AQMixin`
    and that a second generation pass decodes to the expected text.
    """
    quant_config = TorchAoConfig("autoquant")

    quantized_model = AutoModelForCausalLM.from_pretrained(
        self.model_name,
        torch_dtype=torch.bfloat16,
        device_map=torch_device,
        quantization_config=quant_config,
    )
    tokenizer = AutoTokenizer.from_pretrained(self.model_name)
    input_ids = tokenizer(self.input_text, return_tensors="pt").to(torch_device)

    # First pass is run only for its side effects; its result is intentionally discarded.
    # NOTE(review): presumably autoquant needs to observe real inputs before
    # finalize_autoquant() can pick the quantization per layer -- confirm against torchao docs.
    quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens, cache_implementation="static")
    quantized_model.finalize_autoquant()

    check_autoquantized(self, quantized_model.model.layers[0].self_attn.v_proj)

    EXPECTED_OUTPUT = 'What are we having for dinner?\n\n10. "Dinner is ready'
    output = quantized_model.generate(
        **input_ids, max_new_tokens=self.max_new_tokens, cache_implementation="static"
    )
    self.assertEqual(tokenizer.decode(output[0], skip_special_tokens=True), EXPECTED_OUTPUT)
|
||||
|
||||
|
||||
@require_torch_gpu
|
||||
@require_torchao
|
||||
|
||||
Reference in New Issue
Block a user