Add autoquant support for torchao quantizer (#35503)

* Add autoquant support for torchao quantizer

Summary:
att, also verified that autoquantized model can be saved and loaded:

save: https://gist.github.com/jerryzh168/01d367aaf44dbbbfd4068a4a10a00061
load: https://gist.github.com/jerryzh168/d5c6c401b2abdf18e0b6771341f1525c

Test Plan:
tested locally with above script
model uploaded to https://huggingface.co/jerryzh168/llama3-8b-autoquant

Reviewers:

Subscribers:

Tasks:

Tags:

* add test

* ruff fix

* ruff reformat

* add docs and min_sqnr support

* format

* format

* fix test

* update doc

* format

* remove disable_compile

* format
This commit is contained in:
Jerry Zhang
2025-02-24 06:54:16 -08:00
committed by GitHub
parent 977a61f743
commit 2af272c101
5 changed files with 133 additions and 17 deletions

View File

@@ -22,6 +22,12 @@ pip install --upgrade torch torchao transformers
By default, the weights are loaded in full precision (torch.float32) regardless of the actual data type the weights are stored in such as torch.float16. Set `torch_dtype="auto"` to load the weights in the data type defined in a model's `config.json` file to automatically load the most memory-optimal data type.
## Manually Choose Quantization Types and Settings
`torchao` Provides many commonly used types of quantization, including different dtypes like int4, float8 and different flavors like weight only, dynamic quantization etc., only `int4_weight_only`, `int8_weight_only` and `int8_dynamic_activation_int8_weight` are integrated into hugigngface transformers currently, but we can add more when needed.
Users can manually specify the quantization types and settings they want to use:
```py
import torch
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
@@ -41,19 +47,14 @@ output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implemen
print(tokenizer.decode(output[0], skip_special_tokens=True))
# benchmark the performance
import torch.utils.benchmark as benchmark
from torch._inductor.utils import do_bench_using_profiling
from typing import Callable
def benchmark_fn(f, *args, **kwargs):
# Manual warmup
for _ in range(5):
f(*args, **kwargs)
t0 = benchmark.Timer(
stmt="f(*args, **kwargs)",
globals={"args": args, "kwargs": kwargs, "f": f},
num_threads=torch.get_num_threads(),
)
return f"{(t0.blocked_autorange().mean):.3f}"
def benchmark_fn(func: Callable, *args, **kwargs) -> float:
"""Thin wrapper around do_bench_using_profiling"""
no_args = lambda: func(*args, **kwargs)
time = do_bench_using_profiling(no_args)
return time * 1e3
MAX_NEW_TOKENS = 1000
print("int4wo-128 model:", benchmark_fn(quantized_model.generate, **input_ids, max_new_tokens=MAX_NEW_TOKENS, cache_implementation="static"))
@@ -64,6 +65,47 @@ print("bf16 model:", benchmark_fn(bf16_model.generate, **input_ids, max_new_toke
```
## Automatically Select Quantization Types
`torchao` also provies `autoquant` feature that automatically chooses a quantization type for quantizable layers such as linear based on microbenchmarks of quantizing and compiling a single linear layer.
```py
import torch
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
model_name = "meta-llama/Meta-Llama-3-8B"
quantization_config = TorchAoConfig("autoquant", min_sqnr=None)
quantized_model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto", quantization_config=quantization_config)
tokenizer = AutoTokenizer.from_pretrained(model_name)
input_text = "What are we having for dinner?"
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
# auto-compile the quantized model with `cache_implementation="static"` to get speedup
output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")
# Due to some implementation details we are explicitly calling this now, we may refactor our code and remove this in the future
quantized_model.finalize_autoquant()
print(tokenizer.decode(output[0], skip_special_tokens=True))
# benchmark the performance
from torch._inductor.utils import do_bench_using_profiling
from typing import Callable
def benchmark_fn(func: Callable, *args, **kwargs) -> float:
"""Thin wrapper around do_bench_using_profiling"""
no_args = lambda: func(*args, **kwargs)
time = do_bench_using_profiling(no_args)
return time * 1e3
MAX_NEW_TOKENS = 1000
print("autoquantized model:", benchmark_fn(quantized_model.generate, **input_ids, max_new_tokens=MAX_NEW_TOKENS, cache_implementation="static"))
bf16_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="cuda", torch_dtype=torch.bfloat16)
output = bf16_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static") # auto-compile
print("bf16 model:", benchmark_fn(bf16_model.generate, **input_ids, max_new_tokens=MAX_NEW_TOKENS, cache_implementation="static"))
```
## Serialization and Deserialization
torchao quantization is implemented with [tensor subclasses](https://pytorch.org/docs/stable/notes/extending.html#subclassing-torch-tensor), it only work with huggingface non-safetensor serialization and deserialization. It relies on `torch.load(..., weights_only=True)` to avoid arbitrary user code execution during load time and use [add_safe_globals](https://pytorch.org/docs/stable/notes/serialization.html#torch.serialization.add_safe_globals) to allowlist some known user functions.