Enable non-safetensor ser/deser for TorchAoConfig quantized model 🔴 (#33456)

* Enable non-safetensor serialization and deserialization for TorchAoConfig quantized model

Summary:
After https://github.com/huggingface/huggingface_hub/pull/2440 we added non-safetensor serialization and deserialization
in huggingface, with this we can now add the support in transformers

Note that we don't plan to add safetensor serialization due to different goals of wrapper tensor subclass and safetensor
see README for more details

Test Plan:
tested locally

Reviewers:

Subscribers:

Tasks:

Tags:

* formatting

* formatting

* minor fix

* formatting

* address comments

* comments

* minor fix

* update doc

* refactor compressed tensor quantizer
This commit is contained in:
Jerry Zhang
2024-09-30 02:30:29 -07:00
committed by GitHub
parent 2e24ee4dfa
commit 4bb49d4e00
15 changed files with 134 additions and 64 deletions

View File

@@ -11,7 +11,7 @@ rendered properly in your Markdown viewer.
# TorchAO
[TorchAO](https://github.com/pytorch/ao) is an architecture optimization library for PyTorch, it provides high performance dtypes, optimization techniques and kernels for inference and training, featuring composability with native PyTorch features like `torch.compile`, FSDP etc.. Some benchmark numbers can be found [here](https://github.com/pytorch/ao/tree/main?tab=readme-ov-file#without-intrusive-code-changes)
[TorchAO](https://github.com/pytorch/ao) is an architecture optimization library for PyTorch, it provides high performance dtypes, optimization techniques and kernels for inference and training, featuring composability with native PyTorch features like `torch.compile`, FSDP etc.. Some benchmark numbers can be found [here](https://github.com/pytorch/ao/tree/main/torchao/quantization#benchmarks).
Before you begin, make sure the following libraries are installed with their latest version:
@@ -21,6 +21,7 @@ pip install --upgrade torch torchao
```py
import torch
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
model_name = "meta-llama/Meta-Llama-3-8B"
@@ -40,6 +41,51 @@ quantized_model = torch.compile(quantized_model, mode="max-autotune")
output = quantized_model.generate(**input_ids, max_new_tokens=10)
print(tokenizer.decode(output[0], skip_special_tokens=True))
# benchmark the performance
import torch.utils.benchmark as benchmark
def benchmark_fn(f, *args, **kwargs):
# Manual warmup
for _ in range(5):
f(*args, **kwargs)
t0 = benchmark.Timer(
stmt="f(*args, **kwargs)",
globals={"args": args, "kwargs": kwargs, "f": f},
num_threads=torch.get_num_threads(),
)
return f"{(t0.blocked_autorange().mean):.3f}"
MAX_NEW_TOKENS = 1000
print("int4wo-128 model:", benchmark_fn(quantized_model.generate, **input_ids, max_new_tokens=MAX_NEW_TOKENS))
bf16_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="cuda", torch_dtype=torch.bfloat16)
bf16_model = torch.compile(bf16_model, mode="max-autotune")
print("bf16 model:", benchmark_fn(bf16_model.generate, **input_ids, max_new_tokens=MAX_NEW_TOKENS))
```
torchao quantization is implemented with tensor subclasses, currently it does not work with huggingface serialization, both the safetensor option and [non-safetensor option](https://github.com/huggingface/transformers/issues/32364), we'll update here with instructions when it's working.
## Serialization and Deserialization
torchao quantization is implemented with [tensor subclasses](https://pytorch.org/docs/stable/notes/extending.html#subclassing-torch-tensor), it only work with huggingface non-safetensor serialization and deserialization. It relies on `torch.load(..., weights_only=True)` to avoid arbitrary user code execution during load time and use [add_safe_globals](https://pytorch.org/docs/stable/notes/serialization.html#torch.serialization.add_safe_globals) to allowlist some known user functions.
The reason why it does not support safe tensor serialization is that wrapper tensor subclass allows maximum flexibility so we want to make sure the effort of supporting new format of quantized Tensor is low, while safe tensor optimizes for maximum safety (no user code execution), it also means we have to make sure to manually support new quantization format.
```py
# save quantized model locally
output_dir = "llama3-8b-int4wo-128"
quantized_model.save_pretrained(output_dir, safe_serialization=False)
# push to huggingface hub
# save_to = "{user_id}/llama3-8b-int4wo-128"
# quantized_model.push_to_hub(save_to, safe_serialization=False)
# load quantized model
ckpt_id = "llama3-8b-int4wo-128" # or huggingface hub model id
loaded_quantized_model = AutoModelForCausalLM.from_pretrained(ckpt_id, device_map="cuda")
# confirm the speedup
loaded_quantized_model = torch.compile(loaded_quantized_model, mode="max-autotune")
print("loaded int4wo-128 model:", benchmark_fn(loaded_quantized_model.generate, **input_ids, max_new_tokens=MAX_NEW_TOKENS))
```