enable torchao quantization on CPU (#36146)

* enable torchao quantization on CPU

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix int4

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix format

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* enable CPU torchao tests

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix cuda tests

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix cpu tests

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* update tests

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix style

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix cuda tests

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix torchao available

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix torchao available

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix torchao config cannot convert to json

* fix docs

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* rm to_dict to rebase

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* limited torchao version for CPU

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix format

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix skip

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix format

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* Update src/transformers/testing_utils.py

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* fix cpu test

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix format

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

---------

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
This commit is contained in:
jiqing-feng
2025-02-25 18:06:52 +08:00
committed by GitHub
parent 401543a825
commit 9d6abf9778
5 changed files with 125 additions and 70 deletions

View File

@@ -59,7 +59,7 @@ Use the table below to help you decide which quantization method to use.
| [HQQ](./hqq.md) | 🟢 | 🟢 | 🟢 | 🔴 | 🔴 | 🔴 | 🟢 | 1/8 | 🟢 | 🔴 | 🟢 | https://github.com/mobiusml/hqq/ |
| [optimum-quanto](./quanto.md) | 🟢 | 🟢 | 🟢 | 🔴 | 🟢 | 🔴 | 🟢 | 2/4/8 | 🔴 | 🔴 | 🟢 | https://github.com/huggingface/optimum-quanto |
| [FBGEMM_FP8](./fbgemm_fp8.md) | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | 🔴 | 8 | 🔴 | 🟢 | 🟢 | https://github.com/pytorch/FBGEMM |
| [torchao](./torchao.md) | 🟢 | | 🟢 | 🔴 | 🟡 <sub>5</sub> | 🔴 | | 4/8 | | 🟢🔴 | 🟢 | https://github.com/pytorch/ao |
| [torchao](./torchao.md) | 🟢 | 🟢 | 🟢 | 🔴 | 🟡 <sub>5</sub> | 🔴 | | 4/8 | | 🟢🔴 | 🟢 | https://github.com/pytorch/ao |
| [VPTQ](./vptq.md) | 🔴 | 🔴 | 🟢 | 🟡 | 🔴 | 🔴 | 🟢 | 1/8 | 🔴 | 🟢 | 🟢 | https://github.com/microsoft/VPTQ |
| [SpQR](./spqr.md) | 🔴 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | 🟢 | 3 | 🔴 | 🟢 | 🟢 | https://github.com/Vahe1994/SpQR/ |
| [FINEGRAINED_FP8](./finegrained_fp8.md) | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | 🔴 | 8 | 🔴 | 🟢 | 🟢 | |

View File

@@ -22,9 +22,11 @@ pip install --upgrade torch torchao transformers
By default, the weights are loaded in full precision (torch.float32) regardless of the actual data type the weights are stored in such as torch.float16. Set `torch_dtype="auto"` to load the weights in the data type defined in a model's `config.json` file to automatically load the most memory-optimal data type.
## Manually Choose Quantization Types and Settings
`torchao` Provides many commonly used types of quantization, including different dtypes like int4, float8 and different flavors like weight only, dynamic quantization etc., only `int4_weight_only`, `int8_weight_only` and `int8_dynamic_activation_int8_weight` are integrated into hugigngface transformers currently, but we can add more when needed.
If you want to run the following codes on CPU even with GPU available, just change `device_map="cpu"` and `quantization_config = TorchAoConfig("int4_weight_only", group_size=128, layout=Int4CPULayout())` where `layout` comes from `from torchao.dtypes import Int4CPULayout` which is only available from torchao 0.8.0 and higher.
Users can manually specify the quantization types and settings they want to use:
@@ -40,7 +42,7 @@ quantized_model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="
tokenizer = AutoTokenizer.from_pretrained(model_name)
input_text = "What are we having for dinner?"
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
input_ids = tokenizer(input_text, return_tensors="pt").to(quantized_model.device)
# auto-compile the quantized model with `cache_implementation="static"` to get speedup
output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")
@@ -59,7 +61,7 @@ def benchmark_fn(func: Callable, *args, **kwargs) -> float:
MAX_NEW_TOKENS = 1000
print("int4wo-128 model:", benchmark_fn(quantized_model.generate, **input_ids, max_new_tokens=MAX_NEW_TOKENS, cache_implementation="static"))
bf16_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="cuda", torch_dtype=torch.bfloat16)
bf16_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.bfloat16)
output = bf16_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static") # auto-compile
print("bf16 model:", benchmark_fn(bf16_model.generate, **input_ids, max_new_tokens=MAX_NEW_TOKENS, cache_implementation="static"))
@@ -122,7 +124,7 @@ quantized_model.save_pretrained(output_dir, safe_serialization=False)
# load quantized model
ckpt_id = "llama3-8b-int4wo-128" # or huggingface hub model id
loaded_quantized_model = AutoModelForCausalLM.from_pretrained(ckpt_id, device_map="cuda")
loaded_quantized_model = AutoModelForCausalLM.from_pretrained(ckpt_id, device_map="auto")
# confirm the speedup