[docs] Add int4wo + 2:4 sparsity example to TorchAO README (#38592)

* update quantization readme

* update

---------

Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>
This commit is contained in:
Jesse Cai
2025-06-12 08:17:07 -04:00
committed by GitHub
parent bc68defcac
commit e1812864ab

View File

@@ -38,6 +38,7 @@ torchao supports the [quantization techniques](https://github.com/pytorch/ao/blo
- A8W8 Int8 Dynamic Quantization - A8W8 Int8 Dynamic Quantization
- A16W8 Int8 Weight Only Quantization - A16W8 Int8 Weight Only Quantization
- A16W4 Int4 Weight Only Quantization - A16W4 Int4 Weight Only Quantization
- A16W4 Int4 Weight Only Quantization + 2:4 Sparsity
- Autoquantization - Autoquantization
torchao also supports module level configuration by specifying a dictionary from fully qualified name of module and its corresponding quantization config. This allows skip quantizing certain layers and using different quantization config for different modules. torchao also supports module level configuration by specifying a dictionary from fully qualified name of module and its corresponding quantization config. This allows skip quantizing certain layers and using different quantization config for different modules.
@@ -147,6 +148,37 @@ print(tokenizer.decode(output[0], skip_special_tokens=True))
</hfoption> </hfoption>
</hfoptions> </hfoptions>
</hfoption>
<hfoption id="int4-weight-only-24sparse">
```py
import torch
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
from torchao.quantization import Int4WeightOnlyConfig
from torchao.dtypes import MarlinSparseLayout
quant_config = Int4WeightOnlyConfig(layout=MarlinSparseLayout())
quantization_config = TorchAoConfig(quant_type=quant_config)
# Load and quantize the model with sparsity. A sparse checkpoint is needed to accelerate without accuraccy loss
quantized_model = AutoModelForCausalLM.from_pretrained(
"RedHatAI/Sparse-Llama-3.1-8B-2of4",
torch_dtype=torch.float16,
device_map="cuda",
quantization_config=quantization_config
)
tokenizer = AutoTokenizer.from_pretrained("RedHatAI/Sparse-Llama-3.1-8B-2of4")
input_text = "What are we having for dinner?"
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
# auto-compile the quantized model with `cache_implementation="static"` to get speed up
output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")
print(tokenizer.decode(output[0], skip_special_tokens=True))
```
</hfoption>
</hfoptions>
### A100 GPU ### A100 GPU
<hfoptions id="examples-A100-GPU"> <hfoptions id="examples-A100-GPU">
<hfoption id="int8-dynamic-and-weight-only"> <hfoption id="int8-dynamic-and-weight-only">
@@ -215,6 +247,37 @@ print(tokenizer.decode(output[0], skip_special_tokens=True))
</hfoption> </hfoption>
</hfoptions> </hfoptions>
</hfoption>
<hfoption id="int4-weight-only-24sparse">
```py
import torch
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
from torchao.quantization import Int4WeightOnlyConfig
from torchao.dtypes import MarlinSparseLayout
quant_config = Int4WeightOnlyConfig(layout=MarlinSparseLayout())
quantization_config = TorchAoConfig(quant_type=quant_config)
# Load and quantize the model with sparsity. A sparse checkpoint is needed to accelerate without accuraccy loss
quantized_model = AutoModelForCausalLM.from_pretrained(
"RedHatAI/Sparse-Llama-3.1-8B-2of4",
torch_dtype=torch.float16,
device_map="cuda",
quantization_config=quantization_config
)
tokenizer = AutoTokenizer.from_pretrained("RedHatAI/Sparse-Llama-3.1-8B-2of4")
input_text = "What are we having for dinner?"
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
# auto-compile the quantized model with `cache_implementation="static"` to get speed up
output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")
print(tokenizer.decode(output[0], skip_special_tokens=True))
```
</hfoption>
</hfoptions>
### CPU ### CPU
<hfoptions id="examples-CPU"> <hfoptions id="examples-CPU">
<hfoption id="int8-dynamic-and-weight-only"> <hfoption id="int8-dynamic-and-weight-only">